From 5c20f13c48768a80d9e963f2d14565ff16f99563 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sat, 7 Mar 2026 10:46:37 +0100 Subject: [PATCH 01/27] [management] fix domain uniqueness (#5529) --- .../service/manager/expose_tracker_test.go | 8 ++-- .../reverseproxy/service/manager/manager.go | 14 +++---- .../service/manager/manager_test.go | 38 +++++++++---------- .../modules/reverseproxy/service/service.go | 20 +++++----- management/server/store/sql_store.go | 4 +- management/server/store/store.go | 2 +- management/server/store/store_mock.go | 8 ++-- 7 files changed, 46 insertions(+), 48 deletions(-) diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go index bd9f4b93b..c831b4a22 100644 --- a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -36,11 +36,11 @@ func TestReapExpiredExposes(t *testing.T) { mgr.exposeReaper.reapExpiredExposes(ctx) // Expired service should be deleted - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.Error(t, err, "expired service should be deleted") // Non-expired service should remain - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp2.Domain) require.NoError(t, err, "active service should remain") } @@ -191,14 +191,14 @@ func TestReapSkipsRenewedService(t *testing.T) { // Reaper should skip it because the re-check sees a fresh timestamp mgr.exposeReaper.reapExpiredExposes(ctx) - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.NoError(t, err, "renewed service should survive reaping") } // 
expireEphemeralService backdates meta_last_renewed_at to force expiration. func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { t.Helper() - svc, err := s.GetServiceByDomain(context.Background(), accountID, domain) + svc, err := s.GetServiceByDomain(context.Background(), domain) require.NoError(t, err) expired := time.Now().Add(-2 * exposeTTL) diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index b5e643799..56a1fc98a 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -199,7 +199,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { + if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil { return err } @@ -245,7 +245,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) } - if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { return err } @@ -261,8 +261,8 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee }) } -func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { - existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) +func (m *Manager) 
checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error { + existingService, err := transaction.GetServiceByDomain(ctx, domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { return fmt.Errorf("failed to check existing service: %w", err) @@ -271,7 +271,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St } if existingService != nil && existingService.ID != excludeServiceID { - return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain) + return status.Errorf(status.AlreadyExists, "domain already taken") } return nil @@ -352,7 +352,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se } func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { - if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { + if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil { return err } @@ -805,7 +805,7 @@ func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, // lookupPeerService finds a peer-initiated service by domain and validates ownership. 
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) { - svc, err := m.store.GetServiceByDomain(ctx, accountID, domain) + svc, err := m.store.GetServiceByDomain(ctx, domain) if err != nil { return nil, err } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 196eead22..0cb8fa02a 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -72,7 +72,6 @@ func TestInitializeServiceForCreate(t *testing.T) { func TestCheckDomainAvailable(t *testing.T) { ctx := context.Background() - accountID := "test-account" tests := []struct { name string @@ -88,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "available.com"). + GetServiceByDomain(ctx, "available.com"). Return(nil, status.Errorf(status.NotFound, "not found")) }, expectedError: false, @@ -99,7 +98,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). + GetServiceByDomain(ctx, "exists.com"). Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil) }, expectedError: true, @@ -111,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "service-123", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "exists.com"). + GetServiceByDomain(ctx, "exists.com"). Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) }, expectedError: false, @@ -122,7 +121,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "service-456", setupMock: func(ms *store.MockStore) { ms.EXPECT(). 
- GetServiceByDomain(ctx, accountID, "exists.com"). + GetServiceByDomain(ctx, "exists.com"). Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) }, expectedError: true, @@ -134,7 +133,7 @@ func TestCheckDomainAvailable(t *testing.T) { excludeServiceID: "", setupMock: func(ms *store.MockStore) { ms.EXPECT(). - GetServiceByDomain(ctx, accountID, "error.com"). + GetServiceByDomain(ctx, "error.com"). Return(nil, errors.New("database error")) }, expectedError: true, @@ -150,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) { tt.setupMock(mockStore) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) + err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID) if tt.expectedError { require.Error(t, err) @@ -168,7 +167,6 @@ func TestCheckDomainAvailable(t *testing.T) { func TestCheckDomainAvailable_EdgeCases(t *testing.T) { ctx := context.Background() - accountID := "test-account" t.Run("empty domain", func(t *testing.T) { ctrl := gomock.NewController(t) @@ -176,11 +174,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, ""). + GetServiceByDomain(ctx, ""). Return(nil, status.Errorf(status.NotFound, "not found")) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") + err := mgr.checkDomainAvailable(ctx, mockStore, "", "") assert.NoError(t, err) }) @@ -191,11 +189,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, "test.com"). + GetServiceByDomain(ctx, "test.com"). 
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") + err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "") assert.Error(t, err) sErr, ok := status.FromError(err) @@ -209,11 +207,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT(). - GetServiceByDomain(ctx, accountID, "nil.com"). + GetServiceByDomain(ctx, "nil.com"). Return(nil, nil) mgr := &Manager{} - err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") + err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "") assert.NoError(t, err) }) @@ -241,7 +239,7 @@ func TestPersistNewService(t *testing.T) { // Create another mock for the transaction txMock := store.NewMockStore(ctrl) txMock.EXPECT(). - GetServiceByDomain(ctx, accountID, "new.com"). + GetServiceByDomain(ctx, "new.com"). Return(nil, status.Errorf(status.NotFound, "not found")) txMock.EXPECT(). CreateService(ctx, service). @@ -272,7 +270,7 @@ func TestPersistNewService(t *testing.T) { DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { txMock := store.NewMockStore(ctrl) txMock.EXPECT(). - GetServiceByDomain(ctx, accountID, "existing.com"). + GetServiceByDomain(ctx, "existing.com"). 
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil) return fn(txMock) @@ -814,7 +812,7 @@ func TestCreateServiceFromPeer(t *testing.T) { assert.NotEmpty(t, resp.ServiceURL, "service URL should be set") // Verify service is persisted in store - persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain) require.NoError(t, err) assert.Equal(t, resp.Domain, persisted.Domain) assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral") @@ -977,7 +975,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { require.NoError(t, err) // Verify service is deleted - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.Error(t, err, "service should be deleted") }) @@ -1012,7 +1010,7 @@ func TestStopServiceFromPeer(t *testing.T) { err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) require.NoError(t, err) - _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + _, err = testStore.GetServiceByDomain(ctx, resp.Domain) require.Error(t, err, "service should be deleted") }) } @@ -1031,7 +1029,7 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { require.NoError(t, err) assert.Equal(t, int64(1), count, "one ephemeral service should exist after create") - svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + svc, err := testStore.GetServiceByDomain(ctx, resp.Domain) require.NoError(t, err) err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID) diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index cd9311b44..bfad7fe9a 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -134,7 +134,7 @@ type Service 
struct { ID string `gorm:"primaryKey"` AccountID string `gorm:"index"` Name string - Domain string `gorm:"index"` + Domain string `gorm:"type:varchar(255);uniqueIndex"` ProxyCluster string `gorm:"index"` Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"` Enabled bool @@ -535,15 +535,15 @@ var hopByHopHeaders = map[string]struct{}{ // reservedHeaders are set authoritatively by the proxy or control HTTP framing // and cannot be overridden. var reservedHeaders = map[string]struct{}{ - "Content-Length": {}, - "Content-Type": {}, - "Cookie": {}, - "Forwarded": {}, - "X-Forwarded-For": {}, - "X-Forwarded-Host": {}, - "X-Forwarded-Port": {}, - "X-Forwarded-Proto": {}, - "X-Real-Ip": {}, + "Content-Length": {}, + "Content-Type": {}, + "Cookie": {}, + "Forwarded": {}, + "X-Forwarded-For": {}, + "X-Forwarded-Host": {}, + "X-Forwarded-Port": {}, + "X-Forwarded-Proto": {}, + "X-Real-Ip": {}, } func validateTargetOptions(idx int, opts *TargetOptions) error { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8f147d915..5997c10e2 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4977,9 +4977,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren return service, nil } -func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) { +func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { var service *rpservice.Service - result := s.db.Preload("Targets").Where("account_id = ? 
AND domain = ?", accountID, domain).First(&service) + result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain) diff --git a/management/server/store/store.go b/management/server/store/store.go index 5123cde72..1fa99fd05 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -257,7 +257,7 @@ type Store interface { UpdateService(ctx context.Context, service *rpservice.Service) error DeleteService(ctx context.Context, accountID, serviceID string) error GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) - GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) + GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 414872fbb..130df4485 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1932,18 +1932,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout } // GetServiceByDomain mocks base method. 
-func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) { +func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain) + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) ret0, _ := ret[0].(*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } // GetServiceByDomain indicates an expected call of GetServiceByDomain. -func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain) } // GetServiceByID mocks base method. From 3acd86e34686caaf43748f9d780cb2252c229110 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 9 Mar 2026 10:25:51 +0100 Subject: [PATCH 02/27] [client] "reset connection" error on wake from sleep (#5522) Capture engine reference before actCancel() in cleanupConnection(). After actCancel(), the connectWithRetryRuns goroutine sets engine to nil, causing connectClient.Stop() to skip shutdown. This allows the goroutine to set ErrResetConnection on the shared state after Down() clears it, causing the next Up() to fail. 
--- client/server/server.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/client/server/server.go b/client/server/server.go index cab94238f..2c7d5abc3 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -849,14 +849,26 @@ func (s *Server) cleanupConnection() error { if s.actCancel == nil { return ErrServiceNotUp } + + // Capture the engine reference before cancelling the context. + // After actCancel(), the connectWithRetryRuns goroutine wakes up + // and sets connectClient.engine = nil, causing connectClient.Stop() + // to skip the engine shutdown entirely. + var engine *internal.Engine + if s.connectClient != nil { + engine = s.connectClient.Engine() + } + s.actCancel() if s.connectClient == nil { return nil } - if err := s.connectClient.Stop(); err != nil { - return err + if engine != nil { + if err := engine.Stop(); err != nil { + return err + } } s.connectClient = nil From 30c02ab78c7bf05d610a636f7a2446144dd0ba01 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 9 Mar 2026 12:23:06 +0100 Subject: [PATCH 03/27] [management] use the cache for the pkce state (#5516) --- .../service/manager/manager_test.go | 15 +++-- management/internals/server/boot.go | 12 +++- management/internals/server/server.go | 1 - .../internals/shared/grpc/pkce_verifier.go | 61 +++++++++++++++++++ management/internals/shared/grpc/proxy.go | 59 +++--------------- .../internals/shared/grpc/proxy_test.go | 55 +++++++++++++---- .../shared/grpc/validate_session_test.go | 5 +- management/server/account_test.go | 2 +- .../proxy/auth_callback_integration_test.go | 4 ++ .../testing/testing_tools/channel/channel.go | 6 +- proxy/management_integration_test.go | 4 ++ 11 files changed, 152 insertions(+), 72 deletions(-) create mode 100644 management/internals/shared/grpc/pkce_verifier.go diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go 
b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 0cb8fa02a..ba4e1c805 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -423,8 +423,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { t.Helper() tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100) require.NoError(t, err) - srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) - t.Cleanup(srv.Close) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) return srv } @@ -703,8 +704,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) require.NoError(t, err) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) - t.Cleanup(proxySrv.Close) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) @@ -1134,8 +1136,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) require.NoError(t, err) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) - t.Cleanup(proxySrv.Close) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + proxySrv := 
nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 2049f0051..eb13a15e3 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) s.AfterInit(func(s *BaseServer) { proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetProxyController(s.ServiceProxyController()) @@ -203,6 +203,16 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { }) } +func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore { + return Create(s, func() *nbgrpc.PKCEVerifierStore { + pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) + if err != nil { + log.Fatalf("failed to create PKCE verifier store: %v", err) + } + return pkceStore + }) +} + func (s *BaseServer) AccessLogsManager() accesslogs.Manager { return Create(s, func() accesslogs.Manager { accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 573983a79..9b8716da1 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -248,7 
+248,6 @@ func (s *BaseServer) Stop() error { _ = s.certManager.Listener().Close() } s.GRPCServer().Stop() - s.ReverseProxyGRPCServer().Close() if s.proxyAuthClose != nil { s.proxyAuthClose() s.proxyAuthClose = nil diff --git a/management/internals/shared/grpc/pkce_verifier.go b/management/internals/shared/grpc/pkce_verifier.go new file mode 100644 index 000000000..441e8b051 --- /dev/null +++ b/management/internals/shared/grpc/pkce_verifier.go @@ -0,0 +1,61 @@ +package grpc + +import ( + "context" + "fmt" + "time" + + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" + log "github.com/sirupsen/logrus" + + nbcache "github.com/netbirdio/netbird/management/server/cache" +) + +// PKCEVerifierStore manages PKCE verifiers for OAuth flows. +// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var. +type PKCEVerifierStore struct { + cache *cache.Cache[string] + ctx context.Context +} + +// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection +func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) { + cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) + if err != nil { + return nil, fmt.Errorf("failed to create cache store: %w", err) + } + + return &PKCEVerifierStore{ + cache: cache.New[string](cacheStore), + ctx: ctx, + }, nil +} + +// Store saves a PKCE verifier associated with an OAuth state parameter. +// The verifier is stored with the specified TTL and will be automatically deleted after expiration. 
+func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error { + if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil { + return fmt.Errorf("failed to store PKCE verifier: %w", err) + } + + log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl) + return nil +} + +// LoadAndDelete retrieves and removes a PKCE verifier for the given state. +// Returns the verifier and true if found, or empty string and false if not found. +// This enforces single-use semantics for PKCE verifiers. +func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) { + verifier, err := s.cache.Get(s.ctx, state) + if err != nil { + log.Debugf("PKCE verifier not found for state") + return "", false + } + + if err := s.cache.Delete(s.ctx, state); err != nil { + log.Warnf("Failed to delete PKCE verifier for state: %v", err) + } + + return verifier, true +} diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 308da5e2f..e2d0f1abe 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -82,20 +82,12 @@ type ProxyServiceServer struct { // OIDC configuration for proxy authentication oidcConfig ProxyOIDCConfig - // TODO: use database to store these instead? - // pkceVerifiers stores PKCE code verifiers keyed by OAuth state. - // Entries expire after pkceVerifierTTL to prevent unbounded growth. - pkceVerifiers sync.Map - pkceCleanupCancel context.CancelFunc + // Store for PKCE verifiers + pkceVerifierStore *PKCEVerifierStore } const pkceVerifierTTL = 10 * time.Minute -type pkceEntry struct { - verifier string - createdAt time.Time -} - // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string @@ -107,42 +99,21 @@ type proxyConnection struct { } // NewProxyServiceServer creates a new proxy service server. 
-func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { - ctx, cancel := context.WithCancel(context.Background()) +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { + ctx := context.Background() s := &ProxyServiceServer{ accessLogManager: accessLogMgr, oidcConfig: oidcConfig, tokenStore: tokenStore, + pkceVerifierStore: pkceStore, peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, - pkceCleanupCancel: cancel, } - go s.cleanupPKCEVerifiers(ctx) go s.cleanupStaleProxies(ctx) return s } -// cleanupPKCEVerifiers periodically removes expired PKCE verifiers. -func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) { - ticker := time.NewTicker(pkceVerifierTTL) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - now := time.Now() - s.pkceVerifiers.Range(func(key, value any) bool { - if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL { - s.pkceVerifiers.Delete(key) - } - return true - }) - } - } -} - // cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { ticker := time.NewTicker(5 * time.Minute) @@ -159,11 +130,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { } } -// Close stops background goroutines. 
-func (s *ProxyServiceServer) Close() { - s.pkceCleanupCancel() -} - func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { s.serviceManager = manager } @@ -790,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum) codeVerifier := oauth2.GenerateVerifier() - s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()}) + if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil { + log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err) + return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err) + } return &proto.GetOIDCURLResponse{ Url: (&oauth2.Config{ @@ -827,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string { // ValidateState validates the state parameter from an OAuth callback. // Returns the original redirect URL if valid, or an error if invalid. 
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) { - v, ok := s.pkceVerifiers.LoadAndDelete(state) + verifier, ok := s.pkceVerifierStore.LoadAndDelete(state) if !ok { return "", "", errors.New("no verifier for state") } - entry, ok := v.(pkceEntry) - if !ok { - return "", "", errors.New("invalid verifier for state") - } - if time.Since(entry.createdAt) > pkceVerifierTTL { - return "", "", errors.New("PKCE verifier expired") - } - verifier = entry.verifier // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) parts := strings.Split(state, "|") diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index ddeadac5a..b7abb28b6 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -5,11 +5,10 @@ import ( "crypto/rand" "encoding/base64" "strings" + "sync" "testing" "time" - "sync" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -94,11 +93,16 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda } func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + ctx := context.Background() + tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } s.SetProxyController(newTestProxyController()) @@ -151,11 +155,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { } func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + ctx := 
context.Background() + tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } s.SetProxyController(newTestProxyController()) @@ -185,11 +194,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + ctx := context.Background() + tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } s.SetProxyController(newTestProxyController()) @@ -241,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string { } func TestOAuthState_NeverTheSame(t *testing.T) { + ctx := context.Background() + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } redirectURL := "https://app.example.com/callback" @@ -265,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) { } func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { + ctx := context.Background() + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } // Old format had only 2 parts: base64(url)|hmac - 
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) + require.NoError(t, err) - _, _, err := s.ValidateState("base64url|hmac") + _, _, err = s.ValidateState("base64url|hmac") require.Error(t, err) assert.Contains(t, err.Error(), "invalid state format") } func TestValidateState_RejectsInvalidHMAC(t *testing.T) { + ctx := context.Background() + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } // Store with tampered HMAC - s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) + require.NoError(t, err) - _, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac") + _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac") require.Error(t, err) assert.Contains(t, err.Error(), "invalid state signature") } diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 124ddf620..647e8443b 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -41,7 +41,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) require.NoError(t, err) - proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) 
proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) diff --git a/management/server/account_test.go b/management/server/account_test.go index a073d4fca..fdec43617 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3133,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) if err != nil { return nil, nil, err diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index c7fd08da8..3bed54e80 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -193,6 +193,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) require.NoError(t, err) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + usersManager := users.NewManager(testStore) oidcConfig := nbgrpc.ProxyOIDCConfig{ @@ -206,6 +209,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { proxyService := nbgrpc.NewProxyServiceServer( &testAccessLogManager{}, tokenStore, + pkceStore, oidcConfig, nil, usersManager, diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 1d74f88d5..5e33ad652 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ 
b/management/server/http/testing/testing_tools/channel/channel.go @@ -98,12 +98,16 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy token store: %v", err) } + pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + if err != nil { + t.Fatalf("Failed to create PKCE verifier store: %v", err) + } noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) domainManager := manager.NewManager(store, proxyMgr, permissionsManager) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 3e5a21400..6a0ecce30 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -116,6 +116,9 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) require.NoError(t, err) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + // Create real users manager usersManager := users.NewManager(testStore) @@ -131,6 +134,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { proxyService := nbgrpc.NewProxyServiceServer( &testAccessLogManager{}, tokenStore, + pkceStore, oidcConfig, nil, usersManager, From 11eb725ac8c56fb2430f0b940903ccba9c15c781 Mon Sep 17 00:00:00 2001 From: Pascal 
Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 9 Mar 2026 14:56:46 +0100 Subject: [PATCH 04/27] [management] only count login request duration for successful logins (#5545) --- management/internals/shared/grpc/server.go | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index a07cafe90..6e8358f02 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -330,13 +330,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) - } - unlock() unlock = nil + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) + } log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart)) s.syncSem.Add(-1) @@ -743,13 +742,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) - defer func() { - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) - } - log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart)) - }() - if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. 
Peer %s, remote addr %s", peerKey.String(), realIP) @@ -799,6 +791,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto return nil, status.Errorf(codes.Internal, "failed logging in peer") } + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) + } + log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart)) + return &proto.EncryptedMessage{ WgPubKey: key.PublicKey().String(), Body: encryptedResp, From 15aa6bae1b393459fe2300b815be1fe69331b20e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 9 Mar 2026 18:39:11 +0100 Subject: [PATCH 05/27] [client] Fix exit node menu not refreshing on Windows (#5553) * [client] Fix exit node menu not refreshing on Windows TrayOpenedCh is not implemented in the systray library on Windows, so exit nodes were never refreshed after the initial connect. Combined with the management sync not having populated routes yet when the Connected status fires, this caused the exit node menu to remain empty permanently after disconnect/reconnect cycles. Add a background poller on Windows that refreshes exit nodes while connected, with fast initial polling to catch routes from management sync followed by a steady 10s interval. On macOS/Linux, TrayOpenedCh continues to handle refreshes on each tray open. Also fix a data race on connectClient assignment in the server's connect() method and add nil checks in CleanState/DeleteState to prevent panics when connectClient is nil. 
* Remove unused exitNodeIDs * Remove unused exitNodeState struct --- client/server/server.go | 11 +- client/server/server_connect_test.go | 187 +++++++++++++++++++++++++++ client/server/state.go | 4 +- client/ui/client_ui.go | 5 +- client/ui/event_handler.go | 3 +- client/ui/network.go | 95 +++++++++----- 6 files changed, 267 insertions(+), 38 deletions(-) create mode 100644 client/server/server_connect_test.go diff --git a/client/server/server.go b/client/server/server.go index 2c7d5abc3..69d79d9cd 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1625,9 +1625,14 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - if err := s.connectClient.Run(runningChan, s.logFile); err != nil { + client := internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) + client.SetSyncResponsePersistence(s.persistSyncResponse) + + s.mutex.Lock() + s.connectClient = client + s.mutex.Unlock() + + if err := client.Run(runningChan, s.logFile); err != nil { return err } return nil diff --git a/client/server/server_connect_test.go b/client/server/server_connect_test.go new file mode 100644 index 000000000..8d31c2ae6 --- /dev/null +++ b/client/server/server_connect_test.go @@ -0,0 +1,187 @@ +package server + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/proto" +) + +func newTestServer() *Server { + return &Server{ + rootCtx: 
context.Background(), + statusRecorder: peer.NewRecorder(""), + } +} + +func newDummyConnectClient(ctx context.Context) *internal.ConnectClient { + return internal.NewConnectClient(ctx, nil, nil, false) +} + +// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient +// under mutex protection so concurrent readers see a consistent value. +func TestConnectSetsClientWithMutex(t *testing.T) { + s := newTestServer() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Manually simulate what connect() does (without calling Run which panics without full setup) + client := newDummyConnectClient(ctx) + + s.mutex.Lock() + s.connectClient = client + s.mutex.Unlock() + + // Verify the assignment is visible under mutex + s.mutex.Lock() + assert.Equal(t, client, s.connectClient, "connectClient should be set") + s.mutex.Unlock() +} + +// TestConcurrentConnectClientAccess validates that concurrent reads of +// s.connectClient under mutex don't race with a write. +func TestConcurrentConnectClientAccess(t *testing.T) { + s := newTestServer() + ctx := context.Background() + client := newDummyConnectClient(ctx) + + var wg sync.WaitGroup + nilCount := 0 + setCount := 0 + var mu sync.Mutex + + // Start readers + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + s.mutex.Lock() + c := s.connectClient + s.mutex.Unlock() + + mu.Lock() + defer mu.Unlock() + if c == nil { + nilCount++ + } else { + setCount++ + } + }() + } + + // Simulate connect() writing under mutex + time.Sleep(5 * time.Millisecond) + s.mutex.Lock() + s.connectClient = client + s.mutex.Unlock() + + wg.Wait() + + assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic") +} + +// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection +// properly nils out connectClient. 
+func TestCleanupConnection_ClearsConnectClient(t *testing.T) { + s := newTestServer() + _, cancel := context.WithCancel(context.Background()) + s.actCancel = cancel + + s.connectClient = newDummyConnectClient(context.Background()) + s.clientRunning = true + + err := s.cleanupConnection() + require.NoError(t, err) + + assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup") +} + +// TestCleanState_NilConnectClient validates that CleanState doesn't panic +// when connectClient is nil. +func TestCleanState_NilConnectClient(t *testing.T) { + s := newTestServer() + s.connectClient = nil + s.profileManager = nil // will cause error if it tries to proceed past the nil check + + // Should not panic — the nil check should prevent calling Status() on nil + assert.NotPanics(t, func() { + _, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true}) + }) +} + +// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic +// when connectClient is nil. +func TestDeleteState_NilConnectClient(t *testing.T) { + s := newTestServer() + s.connectClient = nil + s.profileManager = nil + + assert.NotPanics(t, func() { + _, _ = s.DeleteState(context.Background(), &proto.DeleteStateRequest{All: true}) + }) +} + +// TestDownThenUp_StaleRunningChan documents the known state issue where +// clientRunningChan from a previous connection is already closed, causing +// waitForUp() to return immediately on reconnect. 
+func TestDownThenUp_StaleRunningChan(t *testing.T) { + s := newTestServer() + + // Simulate state after a successful connection + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + close(s.clientRunningChan) // closed when engine started + s.clientGiveUpChan = make(chan struct{}) + s.connectClient = newDummyConnectClient(context.Background()) + + _, cancel := context.WithCancel(context.Background()) + s.actCancel = cancel + + // Simulate Down(): cleanupConnection sets connectClient = nil + s.mutex.Lock() + err := s.cleanupConnection() + s.mutex.Unlock() + require.NoError(t, err) + + // After cleanup: connectClient is nil, clientRunning still true + // (goroutine hasn't exited yet) + s.mutex.Lock() + assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup") + assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits") + s.mutex.Unlock() + + // waitForUp() returns immediately due to stale closed clientRunningChan + ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer ctxCancel() + + waitDone := make(chan error, 1) + go func() { + _, err := s.waitForUp(ctx) + waitDone <- err + }() + + select { + case err := <-waitDone: + assert.NoError(t, err, "waitForUp returns success on stale channel") + // But connectClient is still nil — this is the stale state issue + s.mutex.Lock() + assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success") + s.mutex.Unlock() + case <-time.After(1 * time.Second): + t.Fatal("waitForUp should have returned immediately due to stale closed channel") + } +} + +// TestConnectClient_EngineNilOnFreshClient validates that a newly created +// ConnectClient has nil Engine (before Run is called). 
+func TestConnectClient_EngineNilOnFreshClient(t *testing.T) { + client := newDummyConnectClient(context.Background()) + assert.Nil(t, client.Engine(), "engine should be nil on fresh ConnectClient") +} diff --git a/client/server/state.go b/client/server/state.go index 1cf85cd37..8dca6bde1 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -39,7 +39,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro // CleanState handles cleaning of states (performing cleanup operations) func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) { - if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { + if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) { return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") } @@ -82,7 +82,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) ( // DeleteState handles deletion of states without cleanup func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) { - if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { + if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) { return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 0290e17d5..7af00cd20 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -323,7 +323,7 @@ type serviceClient struct { exitNodeMu sync.Mutex mExitNodeItems []menuHandler - exitNodeStates []exitNodeState + 
exitNodeRetryCancel context.CancelFunc mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window @@ -924,7 +924,7 @@ func (s *serviceClient) updateStatus() error { s.mDown.Enable() s.mNetworks.Enable() s.mExitNode.Enable() - go s.updateExitNodes() + s.startExitNodeRefresh() systrayIconState = true case status.Status == string(internal.StatusConnecting): s.setConnectingStatus() @@ -985,6 +985,7 @@ func (s *serviceClient) setDisconnectedStatus() { s.mUp.Enable() s.mNetworks.Disable() s.mExitNode.Disable() + s.cancelExitNodeRetry() go s.updateExitNodes() } diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 2216c8aeb..6adf8778c 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -100,8 +100,7 @@ func (h *eventHandler) handleConnectClick() { func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() - - h.client.exitNodeStates = []exitNodeState{} + h.client.cancelExitNodeRetry() if h.client.connectCancel != nil { log.Debugf("cancelling ongoing connect operation") diff --git a/client/ui/network.go b/client/ui/network.go index 9a5ad7662..ed03f5ada 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "runtime" - "slices" "sort" "strings" "time" @@ -34,11 +33,6 @@ const ( type filter string -type exitNodeState struct { - id string - selected bool -} - func (s *serviceClient) showNetworksUI() { s.wNetworks = s.app.NewWindow("Networks") s.wNetworks.SetOnClosed(s.cancel) @@ -335,16 +329,75 @@ func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, s.updateNetworks(grid, f) } -func (s *serviceClient) updateExitNodes() { +// startExitNodeRefresh initiates exit node menu refresh after connecting. +// On Windows, TrayOpenedCh is not supported by the systray library, so we use +// a background poller to keep exit nodes in sync while connected. +// On macOS/Linux, TrayOpenedCh handles refreshes on each tray open. 
+func (s *serviceClient) startExitNodeRefresh() { + s.cancelExitNodeRetry() + + if runtime.GOOS == "windows" { + ctx, cancel := context.WithCancel(s.ctx) + s.exitNodeMu.Lock() + s.exitNodeRetryCancel = cancel + s.exitNodeMu.Unlock() + + go s.pollExitNodes(ctx) + } else { + go s.updateExitNodes() + } +} + +func (s *serviceClient) cancelExitNodeRetry() { + s.exitNodeMu.Lock() + if s.exitNodeRetryCancel != nil { + s.exitNodeRetryCancel() + s.exitNodeRetryCancel = nil + } + s.exitNodeMu.Unlock() +} + +// pollExitNodes periodically refreshes exit nodes while connected. +// Uses a short initial interval to catch routes from the management sync, +// then switches to a longer interval for ongoing updates. +func (s *serviceClient) pollExitNodes(ctx context.Context) { + // Initial fast polling to catch routes as they appear after connect. + for i := 0; i < 5; i++ { + if s.updateExitNodes() { + break + } + select { + case <-ctx.Done(): + return + case <-time.After(2 * time.Second): + } + } + + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.updateExitNodes() + } + } +} + +// updateExitNodes fetches exit nodes from the daemon and recreates the menu. +// Returns true if exit nodes were found. 
+func (s *serviceClient) updateExitNodes() bool { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf("get client: %v", err) - return + return false } exitNodes, err := s.getExitNodes(conn) if err != nil { log.Errorf("get exit nodes: %v", err) - return + return false } s.exitNodeMu.Lock() @@ -354,28 +407,14 @@ func (s *serviceClient) updateExitNodes() { if len(s.mExitNodeItems) > 0 { s.mExitNode.Enable() - } else { - s.mExitNode.Disable() + return true } + + s.mExitNode.Disable() + return false } func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { - var exitNodeIDs []exitNodeState - for _, node := range exitNodes { - exitNodeIDs = append(exitNodeIDs, exitNodeState{ - id: node.ID, - selected: node.Selected, - }) - } - - sort.Slice(exitNodeIDs, func(i, j int) bool { - return exitNodeIDs[i].id < exitNodeIDs[j].id - }) - if slices.Equal(s.exitNodeStates, exitNodeIDs) { - log.Debug("Exit node menu already up to date") - return - } - for _, node := range s.mExitNodeItems { node.cancel() node.Hide() @@ -413,8 +452,6 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { go s.handleChecked(ctx, node.ID, menuItem) } - s.exitNodeStates = exitNodeIDs - if showDeselectAll { s.mExitNode.AddSeparator() deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") From f88429982306156f9edecb3ae8046135c9f0398a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:45:45 +0100 Subject: [PATCH 06/27] [proxy] refactor metrics and add usage logs (#5533) * **New Features** * Access logs now include bytes_upload and bytes_download (API and schemas updated, fields required). * Certificate issuance duration is now recorded as a metric. * **Refactor** * Metrics switched from Prometheus client to OpenTelemetry-backed meters; health endpoint now exposes OpenMetrics via OTLP exporter. 
* **Tests** * Metric tests updated to use OpenTelemetry Prometheus exporter and MeterProvider. --- .../reverseproxy/accesslogs/accesslogentry.go | 6 + proxy/internal/accesslog/logger.go | 124 ++- proxy/internal/accesslog/middleware.go | 16 + proxy/internal/accesslog/statuswriter.go | 25 +- proxy/internal/acme/manager.go | 12 +- proxy/internal/acme/manager_test.go | 4 +- proxy/internal/metrics/metrics.go | 216 +++-- proxy/internal/metrics/metrics_test.go | 20 +- proxy/server.go | 32 +- shared/management/http/api/openapi.yml | 12 + shared/management/http/api/types.gen.go | 814 +++++++++++++++++- shared/management/proto/management.pb.go | 2 +- shared/management/proto/proxy_service.pb.go | 309 +++---- shared/management/proto/proxy_service.proto | 2 + 14 files changed, 1343 insertions(+), 251 deletions(-) diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index 019cb634a..0bcc59b68 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -24,6 +24,8 @@ type AccessLogEntry struct { Reason string UserId string `gorm:"index"` AuthMethodUsed string `gorm:"index"` + BytesUpload int64 `gorm:"index"` + BytesDownload int64 `gorm:"index"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -39,6 +41,8 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) { a.UserId = serviceLog.GetUserId() a.AuthMethodUsed = serviceLog.GetAuthMechanism() a.AccountID = serviceLog.GetAccountId() + a.BytesUpload = serviceLog.GetBytesUpload() + a.BytesDownload = serviceLog.GetBytesDownload() if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { if ip, err := netip.ParseAddr(sourceIP); err == nil { @@ -101,5 +105,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { AuthMethodUsed: authMethod, CountryCode: countryCode, CityName: cityName, + 
BytesUpload: a.BytesUpload, + BytesDownload: a.BytesDownload, } } diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 9e204be65..4ba5a7755 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -3,6 +3,7 @@ package accesslog import ( "context" "net/netip" + "sync" "time" log "github.com/sirupsen/logrus" @@ -13,6 +14,23 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +const ( + requestThreshold = 10000 // Log every 10k requests + bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB + usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour + usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours +) + +type domainUsage struct { + requestCount int64 + requestStartTime time.Time + + bytesTransferred int64 + bytesStartTime time.Time + + lastActivity time.Time // Track last activity for cleanup +} + type gRPCClient interface { SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error) } @@ -22,6 +40,11 @@ type Logger struct { client gRPCClient logger *log.Logger trustedProxies []netip.Prefix + + usageMux sync.Mutex + domainUsage map[string]*domainUsage + + cleanupCancel context.CancelFunc } // NewLogger creates a new access log Logger. The trustedProxies parameter @@ -31,10 +54,26 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre if logger == nil { logger = log.StandardLogger() } - return &Logger{ + + ctx, cancel := context.WithCancel(context.Background()) + l := &Logger{ client: client, logger: logger, trustedProxies: trustedProxies, + domainUsage: make(map[string]*domainUsage), + cleanupCancel: cancel, + } + + // Start background cleanup routine + go l.cleanupStaleUsage(ctx) + + return l +} + +// Close stops the cleanup routine. Should be called during graceful shutdown. 
+func (l *Logger) Close() { + if l.cleanupCancel != nil { + l.cleanupCancel() } } @@ -51,6 +90,8 @@ type logEntry struct { AuthMechanism string UserId string AuthSuccess bool + BytesUpload int64 + BytesDownload int64 } func (l *Logger) log(ctx context.Context, entry logEntry) { @@ -84,6 +125,8 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { AuthMechanism: entry.AuthMechanism, UserId: entry.UserId, AuthSuccess: entry.AuthSuccess, + BytesUpload: entry.BytesUpload, + BytesDownload: entry.BytesDownload, }, }); err != nil { // If it fails to send on the gRPC connection, then at least log it to the error log. @@ -103,3 +146,82 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { } }() } + +// trackUsage records request and byte counts per domain, logging when thresholds are hit. +func (l *Logger) trackUsage(domain string, bytesTransferred int64) { + if domain == "" { + return + } + + l.usageMux.Lock() + defer l.usageMux.Unlock() + + now := time.Now() + usage, exists := l.domainUsage[domain] + if !exists { + usage = &domainUsage{ + requestStartTime: now, + bytesStartTime: now, + lastActivity: now, + } + l.domainUsage[domain] = usage + } + + usage.lastActivity = now + + usage.requestCount++ + if usage.requestCount >= requestThreshold { + elapsed := time.Since(usage.requestStartTime) + l.logger.WithFields(log.Fields{ + "domain": domain, + "requests": usage.requestCount, + "duration": elapsed.String(), + }).Infof("domain %s had %d requests over %s", domain, usage.requestCount, elapsed) + + usage.requestCount = 0 + usage.requestStartTime = now + } + + usage.bytesTransferred += bytesTransferred + if usage.bytesTransferred >= bytesThreshold { + elapsed := time.Since(usage.bytesStartTime) + bytesInGB := float64(usage.bytesTransferred) / (1024 * 1024 * 1024) + l.logger.WithFields(log.Fields{ + "domain": domain, + "bytes": usage.bytesTransferred, + "bytes_gb": bytesInGB, + "duration": elapsed.String(), + }).Infof("domain %s transferred %.2f GB over %s", 
domain, bytesInGB, elapsed) + + usage.bytesTransferred = 0 + usage.bytesStartTime = now + } +} + +// cleanupStaleUsage removes usage entries for domains that have been inactive. +func (l *Logger) cleanupStaleUsage(ctx context.Context) { + ticker := time.NewTicker(usageCleanupPeriod) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + l.usageMux.Lock() + now := time.Now() + removed := 0 + for domain, usage := range l.domainUsage { + if now.Sub(usage.lastActivity) > usageInactiveWindow { + delete(l.domainUsage, domain) + removed++ + } + } + l.usageMux.Unlock() + + if removed > 0 { + l.logger.Debugf("cleaned up %d stale domain usage entries", removed) + } + } + } +} diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index dd4798975..7368185c0 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -32,6 +32,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { status: http.StatusOK, } + var bytesRead int64 + if r.Body != nil { + r.Body = &bodyCounter{ + ReadCloser: r.Body, + bytesRead: &bytesRead, + } + } + // Resolve the source IP using trusted proxy configuration before passing // the request on, as the proxy will modify forwarding headers. 
sourceIp := extractSourceIP(r, l.trustedProxies) @@ -53,6 +61,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { host = r.Host } + bytesUpload := bytesRead + bytesDownload := sw.bytesWritten + entry := logEntry{ ID: requestID, ServiceId: capturedData.GetServiceId(), @@ -66,10 +77,15 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { AuthMechanism: capturedData.GetAuthMethod(), UserId: capturedData.GetUserID(), AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, + BytesUpload: bytesUpload, + BytesDownload: bytesDownload, } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) l.log(r.Context(), entry) + + // Track usage for cost monitoring (upload + download) by domain + l.trackUsage(host, bytesUpload+bytesDownload) }) } diff --git a/proxy/internal/accesslog/statuswriter.go b/proxy/internal/accesslog/statuswriter.go index 43cda59f9..24f7b35e9 100644 --- a/proxy/internal/accesslog/statuswriter.go +++ b/proxy/internal/accesslog/statuswriter.go @@ -1,18 +1,39 @@ package accesslog import ( + "io" + "github.com/netbirdio/netbird/proxy/internal/responsewriter" ) -// statusWriter captures the HTTP status code from WriteHeader calls. +// statusWriter captures the HTTP status code and bytes written from responses. // It embeds responsewriter.PassthroughWriter which handles all the optional // interfaces (Hijacker, Flusher, Pusher) automatically. 
type statusWriter struct { *responsewriter.PassthroughWriter - status int + status int + bytesWritten int64 } func (w *statusWriter) WriteHeader(status int) { w.status = status w.PassthroughWriter.WriteHeader(status) } + +func (w *statusWriter) Write(b []byte) (int, error) { + n, err := w.PassthroughWriter.Write(b) + w.bytesWritten += int64(n) + return n, err +} + +// bodyCounter wraps an io.ReadCloser and counts bytes read from the request body. +type bodyCounter struct { + io.ReadCloser + bytesRead *int64 +} + +func (bc *bodyCounter) Read(p []byte) (int, error) { + n, err := bc.ReadCloser.Read(p) + *bc.bytesRead += int64(n) + return n, err +} diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index d491d65a3..ebc15314b 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -42,6 +42,10 @@ type domainInfo struct { err string } +type metricsRecorder interface { + RecordCertificateIssuance(duration time.Duration) +} + // Manager wraps autocert.Manager with domain tracking and cross-replica // coordination via a pluggable locking strategy. The locker prevents // duplicate ACME requests when multiple replicas share a certificate cache. @@ -55,6 +59,7 @@ type Manager struct { certNotifier certificateNotifier logger *log.Logger + metrics metricsRecorder } // NewManager creates a new ACME certificate manager. The certDir is used @@ -63,7 +68,7 @@ type Manager struct { // eabKID and eabHMACKey are optional External Account Binding credentials // required for some CAs like ZeroSSL. The eabHMACKey should be the base64 // URL-encoded string provided by the CA. 
-func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager { +func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod, metrics metricsRecorder) *Manager { if logger == nil { logger = log.StandardLogger() } @@ -73,6 +78,7 @@ func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificat domains: make(map[domain.Domain]*domainInfo), certNotifier: notifier, logger: logger, + metrics: metrics, } var eab *acme.ExternalAccountBinding @@ -161,6 +167,10 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { return } + if mgr.metrics != nil { + mgr.metrics.RecordCertificateIssuance(elapsed) + } + mgr.setDomainState(d, domainReady, "") now := time.Now() diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go index f7efe5933..30a27c612 100644 --- a/proxy/internal/acme/manager_test.go +++ b/proxy/internal/acme/manager_test.go @@ -10,7 +10,7 @@ import ( ) func TestHostPolicy(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "") + mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil) mgr.AddDomain("example.com", "acc1", "rp1") // Wait for the background prefetch goroutine to finish so the temp dir @@ -70,7 +70,7 @@ func TestHostPolicy(t *testing.T) { } func TestDomainStates(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "") + mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil) assert.Equal(t, 0, mgr.PendingCerts(), "initially zero") assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains") diff --git a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 954020f77..68ff55fe5 100644 --- a/proxy/internal/metrics/metrics.go +++ 
b/proxy/internal/metrics/metrics.go @@ -1,64 +1,106 @@ package metrics import ( + "context" "net/http" - "strconv" + "sync" "time" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" + "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/responsewriter" ) type Metrics struct { - requestsTotal prometheus.Counter - activeRequests prometheus.Gauge - configuredDomains prometheus.Gauge - pathsPerDomain *prometheus.GaugeVec - requestDuration *prometheus.HistogramVec - backendDuration *prometheus.HistogramVec + ctx context.Context + requestsTotal metric.Int64Counter + activeRequests metric.Int64UpDownCounter + configuredDomains metric.Int64UpDownCounter + totalPaths metric.Int64UpDownCounter + requestDuration metric.Int64Histogram + backendDuration metric.Int64Histogram + certificateIssueDuration metric.Int64Histogram + + mappingsMux sync.Mutex + mappingPaths map[string]int } -func New(reg prometheus.Registerer) *Metrics { - promFactory := promauto.With(reg) - return &Metrics{ - requestsTotal: promFactory.NewCounter(prometheus.CounterOpts{ - Name: "netbird_proxy_requests_total", - Help: "Total number of requests made to the netbird proxy", - }), - activeRequests: promFactory.NewGauge(prometheus.GaugeOpts{ - Name: "netbird_proxy_active_requests_count", - Help: "Current in-flight requests handled by the netbird proxy", - }), - configuredDomains: promFactory.NewGauge(prometheus.GaugeOpts{ - Name: "netbird_proxy_domains_count", - Help: "Current number of domains configured on the netbird proxy", - }), - pathsPerDomain: promFactory.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "netbird_proxy_paths_count", - Help: "Current number of paths configured on the netbird proxy labelled by domain", - }, - []string{"domain"}, - ), - requestDuration: promFactory.NewHistogramVec( - prometheus.HistogramOpts{ - Name: 
"netbird_proxy_request_duration_seconds", - Help: "Duration of requests made to the netbird proxy", - Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}, - }, - []string{"status", "size", "method", "host", "path"}, - ), - backendDuration: promFactory.NewHistogramVec(prometheus.HistogramOpts{ - Name: "netbird_proxy_backend_duration_seconds", - Help: "Duration of peer round trip time from the netbird proxy", - Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}, - }, - []string{"status", "size", "method", "host", "path"}, - ), +func New(ctx context.Context, meter metric.Meter) (*Metrics, error) { + requestsTotal, err := meter.Int64Counter( + "proxy.http.request.counter", + metric.WithUnit("1"), + metric.WithDescription("Total number of requests made to the netbird proxy"), + ) + if err != nil { + return nil, err } + + activeRequests, err := meter.Int64UpDownCounter( + "proxy.http.active_requests", + metric.WithUnit("1"), + metric.WithDescription("Current in-flight requests handled by the netbird proxy"), + ) + if err != nil { + return nil, err + } + + configuredDomains, err := meter.Int64UpDownCounter( + "proxy.domains.count", + metric.WithUnit("1"), + metric.WithDescription("Current number of domains configured on the netbird proxy"), + ) + if err != nil { + return nil, err + } + + totalPaths, err := meter.Int64UpDownCounter( + "proxy.paths.count", + metric.WithUnit("1"), + metric.WithDescription("Total number of paths configured on the netbird proxy"), + ) + if err != nil { + return nil, err + } + + requestDuration, err := meter.Int64Histogram( + "proxy.http.request.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of requests made to the netbird proxy"), + ) + if err != nil { + return nil, err + } + + backendDuration, err := meter.Int64Histogram( + "proxy.backend.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of peer round trip time from the 
netbird proxy"), + ) + if err != nil { + return nil, err + } + + certificateIssueDuration, err := meter.Int64Histogram( + "proxy.certificate.issue.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of ACME certificate issuance"), + ) + if err != nil { + return nil, err + } + + return &Metrics{ + ctx: ctx, + requestsTotal: requestsTotal, + activeRequests: activeRequests, + configuredDomains: configuredDomains, + totalPaths: totalPaths, + requestDuration: requestDuration, + backendDuration: backendDuration, + certificateIssueDuration: certificateIssueDuration, + mappingPaths: make(map[string]int), + }, nil } type responseInterceptor struct { @@ -80,23 +122,19 @@ func (w *responseInterceptor) Write(b []byte) (int, error) { func (m *Metrics) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - m.requestsTotal.Inc() - m.activeRequests.Inc() + m.requestsTotal.Add(m.ctx, 1) + m.activeRequests.Add(m.ctx, 1) interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)} start := time.Now() - next.ServeHTTP(interceptor, r) - duration := time.Since(start) + defer func() { + duration := time.Since(start) + m.activeRequests.Add(m.ctx, -1) + m.requestDuration.Record(m.ctx, duration.Milliseconds()) + }() - m.activeRequests.Desc() - m.requestDuration.With(prometheus.Labels{ - "status": strconv.Itoa(interceptor.status), - "size": strconv.Itoa(interceptor.size), - "method": r.Method, - "host": r.Host, - "path": r.URL.Path, - }).Observe(duration.Seconds()) + next.ServeHTTP(interceptor, r) }) } @@ -108,44 +146,52 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { - labels := prometheus.Labels{ - "method": req.Method, - "host": req.Host, - // Fill potentially empty labels with default values to 
avoid cardinality issues. - "path": "/", - "status": "0", - "size": "0", - } - if req.URL != nil { - labels["path"] = req.URL.Path - } - start := time.Now() res, err := next.RoundTrip(req) duration := time.Since(start) - // Not all labels will be available if there was an error. - if res != nil { - labels["status"] = strconv.Itoa(res.StatusCode) - labels["size"] = strconv.Itoa(int(res.ContentLength)) - } - - m.backendDuration.With(labels).Observe(duration.Seconds()) + m.backendDuration.Record(m.ctx, duration.Milliseconds()) return res, err }) } func (m *Metrics) AddMapping(mapping proxy.Mapping) { - m.configuredDomains.Inc() - m.pathsPerDomain.With(prometheus.Labels{ - "domain": mapping.Host, - }).Set(float64(len(mapping.Paths))) + m.mappingsMux.Lock() + defer m.mappingsMux.Unlock() + + newPathCount := len(mapping.Paths) + oldPathCount, exists := m.mappingPaths[mapping.Host] + + if !exists { + m.configuredDomains.Add(m.ctx, 1) + } + + pathDelta := newPathCount - oldPathCount + if pathDelta != 0 { + m.totalPaths.Add(m.ctx, int64(pathDelta)) + } + + m.mappingPaths[mapping.Host] = newPathCount } func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { - m.configuredDomains.Dec() - m.pathsPerDomain.With(prometheus.Labels{ - "domain": mapping.Host, - }).Set(0) + m.mappingsMux.Lock() + defer m.mappingsMux.Unlock() + + oldPathCount, exists := m.mappingPaths[mapping.Host] + if !exists { + // Nothing to remove + return + } + + m.configuredDomains.Add(m.ctx, -1) + m.totalPaths.Add(m.ctx, -int64(oldPathCount)) + + delete(m.mappingPaths, mapping.Host) +} + +// RecordCertificateIssuance records the duration of a certificate issuance. 
+func (m *Metrics) RecordCertificateIssuance(duration time.Duration) { + m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds()) } diff --git a/proxy/internal/metrics/metrics_test.go b/proxy/internal/metrics/metrics_test.go index 31e00ae64..f81072eda 100644 --- a/proxy/internal/metrics/metrics_test.go +++ b/proxy/internal/metrics/metrics_test.go @@ -1,13 +1,17 @@ package metrics_test import ( + "context" "net/http" "net/url" + "reflect" "testing" "github.com/google/go-cmp/cmp" + "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/sdk/metric" + "github.com/netbirdio/netbird/proxy/internal/metrics" - "github.com/prometheus/client_golang/prometheus" ) type testRoundTripper struct { @@ -47,7 +51,19 @@ func TestMetrics_RoundTripper(t *testing.T) { }, } - m := metrics.New(prometheus.NewRegistry()) + exporter, err := prometheus.New() + if err != nil { + t.Fatalf("create prometheus exporter: %v", err) + } + + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath() + meter := provider.Meter(pkg) + + m, err := metrics.New(context.Background(), meter) + if err != nil { + t.Fatalf("create metrics: %v", err) + } for name, test := range tests { t.Run(name, func(t *testing.T) { diff --git a/proxy/server.go b/proxy/server.go index 0d1aa2f6c..123b14648 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -19,14 +19,17 @@ import ( "net/netip" "net/url" "path/filepath" + "reflect" "sync" "time" "github.com/cenkalti/backoff/v4" "github.com/pires/go-proxyproto" - "github.com/prometheus/client_golang/prometheus" + prometheus2 "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/sdk/metric" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -42,7 +45,7 @@ import ( 
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/k8s" - "github.com/netbirdio/netbird/proxy/internal/metrics" + proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/types" @@ -63,7 +66,7 @@ type Server struct { debug *http.Server healthServer *health.Server healthChecker *health.Checker - meter *metrics.Metrics + meter *proxymetrics.Metrics // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) // so they can be closed during graceful shutdown, since http.Server.Shutdown @@ -152,8 +155,19 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.initDefaults() - reg := prometheus.NewRegistry() - s.meter = metrics.New(reg) + exporter, err := prometheus.New() + if err != nil { + return fmt.Errorf("create prometheus exporter: %w", err) + } + + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + pkg := reflect.TypeOf(Server{}).PkgPath() + meter := provider.Meter(pkg) + + s.meter, err = proxymetrics.New(ctx, meter) + if err != nil { + return fmt.Errorf("create metrics: %w", err) + } mgmtConn, err := s.dialManagement() if err != nil { @@ -193,7 +207,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.startDebugEndpoint() - if err := s.startHealthServer(reg); err != nil { + if err := s.startHealthServer(); err != nil { return err } @@ -284,12 +298,12 @@ func (s *Server) startDebugEndpoint() { } // startHealthServer launches the health probe and metrics server. 
-func (s *Server) startHealthServer(reg *prometheus.Registry) error { +func (s *Server) startHealthServer() error { healthAddr := s.HealthAddress if healthAddr == "" { healthAddr = defaultHealthAddr } - s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) + s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true})) healthListener, err := net.Listen("tcp", healthAddr) if err != nil { return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err) @@ -423,7 +437,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { "acme_server": s.ACMEDirectory, "challenge_type": s.ACMEChallengeType, }).Debug("ACME certificates enabled, configuring certificate manager") - s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod) + s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod, s.meter) if s.ACMEChallengeType == "http-01" { s.http = &http.Server{ diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 7f03d6986..c67231342 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -2822,6 +2822,16 @@ components: type: string description: "City name from geolocation" example: "San Francisco" + bytes_upload: + type: integer + format: int64 + description: "Bytes uploaded (request body size)" + example: 1024 + bytes_download: + type: integer + format: int64 + description: "Bytes downloaded (response body size)" + example: 8192 required: - id - service_id @@ -2831,6 +2841,8 @@ components: - path - duration_ms - status_code + - bytes_upload + - bytes_download ProxyAccessLogsResponse: type: object properties: diff --git 
a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index d4a07f806..f218679c0 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT. +// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT. package api import ( @@ -24,6 +24,22 @@ const ( CreateIntegrationRequestPlatformS3 CreateIntegrationRequestPlatform = "s3" ) +// Valid indicates whether the value is a known member of the CreateIntegrationRequestPlatform enum. +func (e CreateIntegrationRequestPlatform) Valid() bool { + switch e { + case CreateIntegrationRequestPlatformDatadog: + return true + case CreateIntegrationRequestPlatformFirehose: + return true + case CreateIntegrationRequestPlatformGenericHttp: + return true + case CreateIntegrationRequestPlatformS3: + return true + default: + return false + } +} + // Defines values for DNSRecordType. const ( DNSRecordTypeA DNSRecordType = "A" @@ -31,6 +47,20 @@ const ( DNSRecordTypeCNAME DNSRecordType = "CNAME" ) +// Valid indicates whether the value is a known member of the DNSRecordType enum. +func (e DNSRecordType) Valid() bool { + switch e { + case DNSRecordTypeA: + return true + case DNSRecordTypeAAAA: + return true + case DNSRecordTypeCNAME: + return true + default: + return false + } +} + // Defines values for EventActivityCode. const ( EventActivityCodeAccountCreate EventActivityCode = "account.create" @@ -147,12 +177,256 @@ const ( EventActivityCodeUserUnblock EventActivityCode = "user.unblock" ) +// Valid indicates whether the value is a known member of the EventActivityCode enum. 
+func (e EventActivityCode) Valid() bool { + switch e { + case EventActivityCodeAccountCreate: + return true + case EventActivityCodeAccountDelete: + return true + case EventActivityCodeAccountDnsDomainUpdate: + return true + case EventActivityCodeAccountNetworkRangeUpdate: + return true + case EventActivityCodeAccountPeerInactivityExpirationDisable: + return true + case EventActivityCodeAccountPeerInactivityExpirationEnable: + return true + case EventActivityCodeAccountPeerInactivityExpirationUpdate: + return true + case EventActivityCodeAccountSettingGroupPropagationDisable: + return true + case EventActivityCodeAccountSettingGroupPropagationEnable: + return true + case EventActivityCodeAccountSettingLazyConnectionDisable: + return true + case EventActivityCodeAccountSettingLazyConnectionEnable: + return true + case EventActivityCodeAccountSettingPeerApprovalDisable: + return true + case EventActivityCodeAccountSettingPeerApprovalEnable: + return true + case EventActivityCodeAccountSettingPeerLoginExpirationDisable: + return true + case EventActivityCodeAccountSettingPeerLoginExpirationEnable: + return true + case EventActivityCodeAccountSettingPeerLoginExpirationUpdate: + return true + case EventActivityCodeAccountSettingRoutingPeerDnsResolutionDisable: + return true + case EventActivityCodeAccountSettingRoutingPeerDnsResolutionEnable: + return true + case EventActivityCodeAccountSettingsAutoVersionUpdate: + return true + case EventActivityCodeDashboardLogin: + return true + case EventActivityCodeDnsSettingDisabledManagementGroupAdd: + return true + case EventActivityCodeDnsSettingDisabledManagementGroupDelete: + return true + case EventActivityCodeDnsZoneCreate: + return true + case EventActivityCodeDnsZoneDelete: + return true + case EventActivityCodeDnsZoneRecordCreate: + return true + case EventActivityCodeDnsZoneRecordDelete: + return true + case EventActivityCodeDnsZoneRecordUpdate: + return true + case EventActivityCodeDnsZoneUpdate: + return true + case 
EventActivityCodeGroupAdd: + return true + case EventActivityCodeGroupDelete: + return true + case EventActivityCodeGroupUpdate: + return true + case EventActivityCodeIdentityproviderCreate: + return true + case EventActivityCodeIdentityproviderDelete: + return true + case EventActivityCodeIdentityproviderUpdate: + return true + case EventActivityCodeIntegrationCreate: + return true + case EventActivityCodeIntegrationDelete: + return true + case EventActivityCodeIntegrationUpdate: + return true + case EventActivityCodeNameserverGroupAdd: + return true + case EventActivityCodeNameserverGroupDelete: + return true + case EventActivityCodeNameserverGroupUpdate: + return true + case EventActivityCodeNetworkCreate: + return true + case EventActivityCodeNetworkDelete: + return true + case EventActivityCodeNetworkResourceCreate: + return true + case EventActivityCodeNetworkResourceDelete: + return true + case EventActivityCodeNetworkResourceUpdate: + return true + case EventActivityCodeNetworkRouterCreate: + return true + case EventActivityCodeNetworkRouterDelete: + return true + case EventActivityCodeNetworkRouterUpdate: + return true + case EventActivityCodeNetworkUpdate: + return true + case EventActivityCodePeerApprovalRevoke: + return true + case EventActivityCodePeerApprove: + return true + case EventActivityCodePeerGroupAdd: + return true + case EventActivityCodePeerGroupDelete: + return true + case EventActivityCodePeerInactivityExpirationDisable: + return true + case EventActivityCodePeerInactivityExpirationEnable: + return true + case EventActivityCodePeerIpUpdate: + return true + case EventActivityCodePeerJobCreate: + return true + case EventActivityCodePeerLoginExpirationDisable: + return true + case EventActivityCodePeerLoginExpirationEnable: + return true + case EventActivityCodePeerLoginExpire: + return true + case EventActivityCodePeerRename: + return true + case EventActivityCodePeerSetupkeyAdd: + return true + case EventActivityCodePeerSshDisable: + 
return true + case EventActivityCodePeerSshEnable: + return true + case EventActivityCodePeerUserAdd: + return true + case EventActivityCodePersonalAccessTokenCreate: + return true + case EventActivityCodePersonalAccessTokenDelete: + return true + case EventActivityCodePolicyAdd: + return true + case EventActivityCodePolicyDelete: + return true + case EventActivityCodePolicyUpdate: + return true + case EventActivityCodePostureCheckCreate: + return true + case EventActivityCodePostureCheckDelete: + return true + case EventActivityCodePostureCheckUpdate: + return true + case EventActivityCodeResourceGroupAdd: + return true + case EventActivityCodeResourceGroupDelete: + return true + case EventActivityCodeRouteAdd: + return true + case EventActivityCodeRouteDelete: + return true + case EventActivityCodeRouteUpdate: + return true + case EventActivityCodeRuleAdd: + return true + case EventActivityCodeRuleDelete: + return true + case EventActivityCodeRuleUpdate: + return true + case EventActivityCodeServiceCreate: + return true + case EventActivityCodeServiceDelete: + return true + case EventActivityCodeServiceUpdate: + return true + case EventActivityCodeServiceUserCreate: + return true + case EventActivityCodeServiceUserDelete: + return true + case EventActivityCodeSetupkeyAdd: + return true + case EventActivityCodeSetupkeyDelete: + return true + case EventActivityCodeSetupkeyGroupAdd: + return true + case EventActivityCodeSetupkeyGroupDelete: + return true + case EventActivityCodeSetupkeyOveruse: + return true + case EventActivityCodeSetupkeyRevoke: + return true + case EventActivityCodeSetupkeyUpdate: + return true + case EventActivityCodeTransferredOwnerRole: + return true + case EventActivityCodeUserApprove: + return true + case EventActivityCodeUserBlock: + return true + case EventActivityCodeUserCreate: + return true + case EventActivityCodeUserDelete: + return true + case EventActivityCodeUserGroupAdd: + return true + case EventActivityCodeUserGroupDelete: + 
return true + case EventActivityCodeUserInvite: + return true + case EventActivityCodeUserInviteLinkAccept: + return true + case EventActivityCodeUserInviteLinkCreate: + return true + case EventActivityCodeUserInviteLinkDelete: + return true + case EventActivityCodeUserInviteLinkRegenerate: + return true + case EventActivityCodeUserJoin: + return true + case EventActivityCodeUserPasswordChange: + return true + case EventActivityCodeUserPeerDelete: + return true + case EventActivityCodeUserPeerLogin: + return true + case EventActivityCodeUserReject: + return true + case EventActivityCodeUserRoleUpdate: + return true + case EventActivityCodeUserUnblock: + return true + default: + return false + } +} + // Defines values for GeoLocationCheckAction. const ( GeoLocationCheckActionAllow GeoLocationCheckAction = "allow" GeoLocationCheckActionDeny GeoLocationCheckAction = "deny" ) +// Valid indicates whether the value is a known member of the GeoLocationCheckAction enum. +func (e GeoLocationCheckAction) Valid() bool { + switch e { + case GeoLocationCheckActionAllow: + return true + case GeoLocationCheckActionDeny: + return true + default: + return false + } +} + // Defines values for GroupIssued. const ( GroupIssuedApi GroupIssued = "api" @@ -160,6 +434,20 @@ const ( GroupIssuedJwt GroupIssued = "jwt" ) +// Valid indicates whether the value is a known member of the GroupIssued enum. +func (e GroupIssued) Valid() bool { + switch e { + case GroupIssuedApi: + return true + case GroupIssuedIntegration: + return true + case GroupIssuedJwt: + return true + default: + return false + } +} + // Defines values for GroupMinimumIssued. const ( GroupMinimumIssuedApi GroupMinimumIssued = "api" @@ -167,6 +455,20 @@ const ( GroupMinimumIssuedJwt GroupMinimumIssued = "jwt" ) +// Valid indicates whether the value is a known member of the GroupMinimumIssued enum. 
+func (e GroupMinimumIssued) Valid() bool { + switch e { + case GroupMinimumIssuedApi: + return true + case GroupMinimumIssuedIntegration: + return true + case GroupMinimumIssuedJwt: + return true + default: + return false + } +} + // Defines values for IdentityProviderType. const ( IdentityProviderTypeEntra IdentityProviderType = "entra" @@ -178,6 +480,28 @@ const ( IdentityProviderTypeZitadel IdentityProviderType = "zitadel" ) +// Valid indicates whether the value is a known member of the IdentityProviderType enum. +func (e IdentityProviderType) Valid() bool { + switch e { + case IdentityProviderTypeEntra: + return true + case IdentityProviderTypeGoogle: + return true + case IdentityProviderTypeMicrosoft: + return true + case IdentityProviderTypeOidc: + return true + case IdentityProviderTypeOkta: + return true + case IdentityProviderTypePocketid: + return true + case IdentityProviderTypeZitadel: + return true + default: + return false + } +} + // Defines values for IngressPortAllocationPortMappingProtocol. const ( IngressPortAllocationPortMappingProtocolTcp IngressPortAllocationPortMappingProtocol = "tcp" @@ -185,6 +509,20 @@ const ( IngressPortAllocationPortMappingProtocolUdp IngressPortAllocationPortMappingProtocol = "udp" ) +// Valid indicates whether the value is a known member of the IngressPortAllocationPortMappingProtocol enum. +func (e IngressPortAllocationPortMappingProtocol) Valid() bool { + switch e { + case IngressPortAllocationPortMappingProtocolTcp: + return true + case IngressPortAllocationPortMappingProtocolTcpudp: + return true + case IngressPortAllocationPortMappingProtocolUdp: + return true + default: + return false + } +} + // Defines values for IngressPortAllocationRequestDirectPortProtocol. 
const ( IngressPortAllocationRequestDirectPortProtocolTcp IngressPortAllocationRequestDirectPortProtocol = "tcp" @@ -192,6 +530,20 @@ const ( IngressPortAllocationRequestDirectPortProtocolUdp IngressPortAllocationRequestDirectPortProtocol = "udp" ) +// Valid indicates whether the value is a known member of the IngressPortAllocationRequestDirectPortProtocol enum. +func (e IngressPortAllocationRequestDirectPortProtocol) Valid() bool { + switch e { + case IngressPortAllocationRequestDirectPortProtocolTcp: + return true + case IngressPortAllocationRequestDirectPortProtocolTcpudp: + return true + case IngressPortAllocationRequestDirectPortProtocolUdp: + return true + default: + return false + } +} + // Defines values for IngressPortAllocationRequestPortRangeProtocol. const ( IngressPortAllocationRequestPortRangeProtocolTcp IngressPortAllocationRequestPortRangeProtocol = "tcp" @@ -199,6 +551,20 @@ const ( IngressPortAllocationRequestPortRangeProtocolUdp IngressPortAllocationRequestPortRangeProtocol = "udp" ) +// Valid indicates whether the value is a known member of the IngressPortAllocationRequestPortRangeProtocol enum. +func (e IngressPortAllocationRequestPortRangeProtocol) Valid() bool { + switch e { + case IngressPortAllocationRequestPortRangeProtocolTcp: + return true + case IngressPortAllocationRequestPortRangeProtocolTcpudp: + return true + case IngressPortAllocationRequestPortRangeProtocolUdp: + return true + default: + return false + } +} + // Defines values for IntegrationResponsePlatform. const ( IntegrationResponsePlatformDatadog IntegrationResponsePlatform = "datadog" @@ -207,12 +573,40 @@ const ( IntegrationResponsePlatformS3 IntegrationResponsePlatform = "s3" ) +// Valid indicates whether the value is a known member of the IntegrationResponsePlatform enum. 
+func (e IntegrationResponsePlatform) Valid() bool { + switch e { + case IntegrationResponsePlatformDatadog: + return true + case IntegrationResponsePlatformFirehose: + return true + case IntegrationResponsePlatformGenericHttp: + return true + case IntegrationResponsePlatformS3: + return true + default: + return false + } +} + // Defines values for InvoiceResponseType. const ( InvoiceResponseTypeAccount InvoiceResponseType = "account" InvoiceResponseTypeTenants InvoiceResponseType = "tenants" ) +// Valid indicates whether the value is a known member of the InvoiceResponseType enum. +func (e InvoiceResponseType) Valid() bool { + switch e { + case InvoiceResponseTypeAccount: + return true + case InvoiceResponseTypeTenants: + return true + default: + return false + } +} + // Defines values for JobResponseStatus. const ( JobResponseStatusFailed JobResponseStatus = "failed" @@ -220,11 +614,35 @@ const ( JobResponseStatusSucceeded JobResponseStatus = "succeeded" ) +// Valid indicates whether the value is a known member of the JobResponseStatus enum. +func (e JobResponseStatus) Valid() bool { + switch e { + case JobResponseStatusFailed: + return true + case JobResponseStatusPending: + return true + case JobResponseStatusSucceeded: + return true + default: + return false + } +} + // Defines values for NameserverNsType. const ( NameserverNsTypeUdp NameserverNsType = "udp" ) +// Valid indicates whether the value is a known member of the NameserverNsType enum. +func (e NameserverNsType) Valid() bool { + switch e { + case NameserverNsTypeUdp: + return true + default: + return false + } +} + // Defines values for NetworkResourceType. const ( NetworkResourceTypeDomain NetworkResourceType = "domain" @@ -232,18 +650,56 @@ const ( NetworkResourceTypeSubnet NetworkResourceType = "subnet" ) +// Valid indicates whether the value is a known member of the NetworkResourceType enum. 
+func (e NetworkResourceType) Valid() bool { + switch e { + case NetworkResourceTypeDomain: + return true + case NetworkResourceTypeHost: + return true + case NetworkResourceTypeSubnet: + return true + default: + return false + } +} + // Defines values for PeerNetworkRangeCheckAction. const ( PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow" PeerNetworkRangeCheckActionDeny PeerNetworkRangeCheckAction = "deny" ) +// Valid indicates whether the value is a known member of the PeerNetworkRangeCheckAction enum. +func (e PeerNetworkRangeCheckAction) Valid() bool { + switch e { + case PeerNetworkRangeCheckActionAllow: + return true + case PeerNetworkRangeCheckActionDeny: + return true + default: + return false + } +} + // Defines values for PolicyRuleAction. const ( PolicyRuleActionAccept PolicyRuleAction = "accept" PolicyRuleActionDrop PolicyRuleAction = "drop" ) +// Valid indicates whether the value is a known member of the PolicyRuleAction enum. +func (e PolicyRuleAction) Valid() bool { + switch e { + case PolicyRuleActionAccept: + return true + case PolicyRuleActionDrop: + return true + default: + return false + } +} + // Defines values for PolicyRuleProtocol. const ( PolicyRuleProtocolAll PolicyRuleProtocol = "all" @@ -253,12 +709,42 @@ const ( PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" ) +// Valid indicates whether the value is a known member of the PolicyRuleProtocol enum. +func (e PolicyRuleProtocol) Valid() bool { + switch e { + case PolicyRuleProtocolAll: + return true + case PolicyRuleProtocolIcmp: + return true + case PolicyRuleProtocolNetbirdSsh: + return true + case PolicyRuleProtocolTcp: + return true + case PolicyRuleProtocolUdp: + return true + default: + return false + } +} + // Defines values for PolicyRuleMinimumAction. 
const ( PolicyRuleMinimumActionAccept PolicyRuleMinimumAction = "accept" PolicyRuleMinimumActionDrop PolicyRuleMinimumAction = "drop" ) +// Valid indicates whether the value is a known member of the PolicyRuleMinimumAction enum. +func (e PolicyRuleMinimumAction) Valid() bool { + switch e { + case PolicyRuleMinimumActionAccept: + return true + case PolicyRuleMinimumActionDrop: + return true + default: + return false + } +} + // Defines values for PolicyRuleMinimumProtocol. const ( PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" @@ -268,12 +754,42 @@ const ( PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" ) +// Valid indicates whether the value is a known member of the PolicyRuleMinimumProtocol enum. +func (e PolicyRuleMinimumProtocol) Valid() bool { + switch e { + case PolicyRuleMinimumProtocolAll: + return true + case PolicyRuleMinimumProtocolIcmp: + return true + case PolicyRuleMinimumProtocolNetbirdSsh: + return true + case PolicyRuleMinimumProtocolTcp: + return true + case PolicyRuleMinimumProtocolUdp: + return true + default: + return false + } +} + // Defines values for PolicyRuleUpdateAction. const ( PolicyRuleUpdateActionAccept PolicyRuleUpdateAction = "accept" PolicyRuleUpdateActionDrop PolicyRuleUpdateAction = "drop" ) +// Valid indicates whether the value is a known member of the PolicyRuleUpdateAction enum. +func (e PolicyRuleUpdateAction) Valid() bool { + switch e { + case PolicyRuleUpdateActionAccept: + return true + case PolicyRuleUpdateActionDrop: + return true + default: + return false + } +} + // Defines values for PolicyRuleUpdateProtocol. const ( PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" @@ -283,6 +799,24 @@ const ( PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) +// Valid indicates whether the value is a known member of the PolicyRuleUpdateProtocol enum. 
+func (e PolicyRuleUpdateProtocol) Valid() bool { + switch e { + case PolicyRuleUpdateProtocolAll: + return true + case PolicyRuleUpdateProtocolIcmp: + return true + case PolicyRuleUpdateProtocolNetbirdSsh: + return true + case PolicyRuleUpdateProtocolTcp: + return true + case PolicyRuleUpdateProtocolUdp: + return true + default: + return false + } +} + // Defines values for ResourceType. const ( ResourceTypeDomain ResourceType = "domain" @@ -291,12 +825,40 @@ const ( ResourceTypeSubnet ResourceType = "subnet" ) +// Valid indicates whether the value is a known member of the ResourceType enum. +func (e ResourceType) Valid() bool { + switch e { + case ResourceTypeDomain: + return true + case ResourceTypeHost: + return true + case ResourceTypePeer: + return true + case ResourceTypeSubnet: + return true + default: + return false + } +} + // Defines values for ReverseProxyDomainType. const ( ReverseProxyDomainTypeCustom ReverseProxyDomainType = "custom" ReverseProxyDomainTypeFree ReverseProxyDomainType = "free" ) +// Valid indicates whether the value is a known member of the ReverseProxyDomainType enum. +func (e ReverseProxyDomainType) Valid() bool { + switch e { + case ReverseProxyDomainTypeCustom: + return true + case ReverseProxyDomainTypeFree: + return true + default: + return false + } +} + // Defines values for SentinelOneMatchAttributesNetworkStatus. const ( SentinelOneMatchAttributesNetworkStatusConnected SentinelOneMatchAttributesNetworkStatus = "connected" @@ -304,6 +866,20 @@ const ( SentinelOneMatchAttributesNetworkStatusQuarantined SentinelOneMatchAttributesNetworkStatus = "quarantined" ) +// Valid indicates whether the value is a known member of the SentinelOneMatchAttributesNetworkStatus enum. 
+func (e SentinelOneMatchAttributesNetworkStatus) Valid() bool { + switch e { + case SentinelOneMatchAttributesNetworkStatusConnected: + return true + case SentinelOneMatchAttributesNetworkStatusDisconnected: + return true + case SentinelOneMatchAttributesNetworkStatusQuarantined: + return true + default: + return false + } +} + // Defines values for ServiceMetaStatus. const ( ServiceMetaStatusActive ServiceMetaStatus = "active" @@ -314,23 +890,77 @@ const ( ServiceMetaStatusTunnelNotCreated ServiceMetaStatus = "tunnel_not_created" ) +// Valid indicates whether the value is a known member of the ServiceMetaStatus enum. +func (e ServiceMetaStatus) Valid() bool { + switch e { + case ServiceMetaStatusActive: + return true + case ServiceMetaStatusCertificateFailed: + return true + case ServiceMetaStatusCertificatePending: + return true + case ServiceMetaStatusError: + return true + case ServiceMetaStatusPending: + return true + case ServiceMetaStatusTunnelNotCreated: + return true + default: + return false + } +} + // Defines values for ServiceTargetProtocol. const ( ServiceTargetProtocolHttp ServiceTargetProtocol = "http" ServiceTargetProtocolHttps ServiceTargetProtocol = "https" ) +// Valid indicates whether the value is a known member of the ServiceTargetProtocol enum. +func (e ServiceTargetProtocol) Valid() bool { + switch e { + case ServiceTargetProtocolHttp: + return true + case ServiceTargetProtocolHttps: + return true + default: + return false + } +} + // Defines values for ServiceTargetTargetType. const ( ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" ServiceTargetTargetTypeResource ServiceTargetTargetType = "resource" ) +// Valid indicates whether the value is a known member of the ServiceTargetTargetType enum. 
+func (e ServiceTargetTargetType) Valid() bool { + switch e { + case ServiceTargetTargetTypePeer: + return true + case ServiceTargetTargetTypeResource: + return true + default: + return false + } +} + // Defines values for ServiceTargetOptionsPathRewrite. const ( ServiceTargetOptionsPathRewritePreserve ServiceTargetOptionsPathRewrite = "preserve" ) +// Valid indicates whether the value is a known member of the ServiceTargetOptionsPathRewrite enum. +func (e ServiceTargetOptionsPathRewrite) Valid() bool { + switch e { + case ServiceTargetOptionsPathRewritePreserve: + return true + default: + return false + } +} + // Defines values for TenantResponseStatus. const ( TenantResponseStatusActive TenantResponseStatus = "active" @@ -339,6 +969,22 @@ const ( TenantResponseStatusPending TenantResponseStatus = "pending" ) +// Valid indicates whether the value is a known member of the TenantResponseStatus enum. +func (e TenantResponseStatus) Valid() bool { + switch e { + case TenantResponseStatusActive: + return true + case TenantResponseStatusExisting: + return true + case TenantResponseStatusInvited: + return true + case TenantResponseStatusPending: + return true + default: + return false + } +} + // Defines values for UserStatus. const ( UserStatusActive UserStatus = "active" @@ -346,11 +992,35 @@ const ( UserStatusInvited UserStatus = "invited" ) +// Valid indicates whether the value is a known member of the UserStatus enum. +func (e UserStatus) Valid() bool { + switch e { + case UserStatusActive: + return true + case UserStatusBlocked: + return true + case UserStatusInvited: + return true + default: + return false + } +} + // Defines values for WorkloadType. const ( WorkloadTypeBundle WorkloadType = "bundle" ) +// Valid indicates whether the value is a known member of the WorkloadType enum. 
+func (e WorkloadType) Valid() bool { + switch e { + case WorkloadTypeBundle: + return true + default: + return false + } +} + // Defines values for GetApiEventsNetworkTrafficParamsType. const ( GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP" @@ -359,12 +1029,40 @@ const ( GetApiEventsNetworkTrafficParamsTypeTYPEUNKNOWN GetApiEventsNetworkTrafficParamsType = "TYPE_UNKNOWN" ) +// Valid indicates whether the value is a known member of the GetApiEventsNetworkTrafficParamsType enum. +func (e GetApiEventsNetworkTrafficParamsType) Valid() bool { + switch e { + case GetApiEventsNetworkTrafficParamsTypeTYPEDROP: + return true + case GetApiEventsNetworkTrafficParamsTypeTYPEEND: + return true + case GetApiEventsNetworkTrafficParamsTypeTYPESTART: + return true + case GetApiEventsNetworkTrafficParamsTypeTYPEUNKNOWN: + return true + default: + return false + } +} + // Defines values for GetApiEventsNetworkTrafficParamsConnectionType. const ( GetApiEventsNetworkTrafficParamsConnectionTypeP2P GetApiEventsNetworkTrafficParamsConnectionType = "P2P" GetApiEventsNetworkTrafficParamsConnectionTypeROUTED GetApiEventsNetworkTrafficParamsConnectionType = "ROUTED" ) +// Valid indicates whether the value is a known member of the GetApiEventsNetworkTrafficParamsConnectionType enum. +func (e GetApiEventsNetworkTrafficParamsConnectionType) Valid() bool { + switch e { + case GetApiEventsNetworkTrafficParamsConnectionTypeP2P: + return true + case GetApiEventsNetworkTrafficParamsConnectionTypeROUTED: + return true + default: + return false + } +} + // Defines values for GetApiEventsNetworkTrafficParamsDirection. 
const ( GetApiEventsNetworkTrafficParamsDirectionDIRECTIONUNKNOWN GetApiEventsNetworkTrafficParamsDirection = "DIRECTION_UNKNOWN" @@ -372,6 +1070,20 @@ const ( GetApiEventsNetworkTrafficParamsDirectionINGRESS GetApiEventsNetworkTrafficParamsDirection = "INGRESS" ) +// Valid indicates whether the value is a known member of the GetApiEventsNetworkTrafficParamsDirection enum. +func (e GetApiEventsNetworkTrafficParamsDirection) Valid() bool { + switch e { + case GetApiEventsNetworkTrafficParamsDirectionDIRECTIONUNKNOWN: + return true + case GetApiEventsNetworkTrafficParamsDirectionEGRESS: + return true + case GetApiEventsNetworkTrafficParamsDirectionINGRESS: + return true + default: + return false + } +} + // Defines values for GetApiEventsProxyParamsSortBy. const ( GetApiEventsProxyParamsSortByAuthMethod GetApiEventsProxyParamsSortBy = "auth_method" @@ -387,12 +1099,54 @@ const ( GetApiEventsProxyParamsSortByUserId GetApiEventsProxyParamsSortBy = "user_id" ) +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsSortBy enum. +func (e GetApiEventsProxyParamsSortBy) Valid() bool { + switch e { + case GetApiEventsProxyParamsSortByAuthMethod: + return true + case GetApiEventsProxyParamsSortByDuration: + return true + case GetApiEventsProxyParamsSortByHost: + return true + case GetApiEventsProxyParamsSortByMethod: + return true + case GetApiEventsProxyParamsSortByPath: + return true + case GetApiEventsProxyParamsSortByReason: + return true + case GetApiEventsProxyParamsSortBySourceIp: + return true + case GetApiEventsProxyParamsSortByStatusCode: + return true + case GetApiEventsProxyParamsSortByTimestamp: + return true + case GetApiEventsProxyParamsSortByUrl: + return true + case GetApiEventsProxyParamsSortByUserId: + return true + default: + return false + } +} + // Defines values for GetApiEventsProxyParamsSortOrder. 
const ( GetApiEventsProxyParamsSortOrderAsc GetApiEventsProxyParamsSortOrder = "asc" GetApiEventsProxyParamsSortOrderDesc GetApiEventsProxyParamsSortOrder = "desc" ) +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsSortOrder enum. +func (e GetApiEventsProxyParamsSortOrder) Valid() bool { + switch e { + case GetApiEventsProxyParamsSortOrderAsc: + return true + case GetApiEventsProxyParamsSortOrderDesc: + return true + default: + return false + } +} + // Defines values for GetApiEventsProxyParamsMethod. const ( GetApiEventsProxyParamsMethodDELETE GetApiEventsProxyParamsMethod = "DELETE" @@ -404,18 +1158,64 @@ const ( GetApiEventsProxyParamsMethodPUT GetApiEventsProxyParamsMethod = "PUT" ) +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsMethod enum. +func (e GetApiEventsProxyParamsMethod) Valid() bool { + switch e { + case GetApiEventsProxyParamsMethodDELETE: + return true + case GetApiEventsProxyParamsMethodGET: + return true + case GetApiEventsProxyParamsMethodHEAD: + return true + case GetApiEventsProxyParamsMethodOPTIONS: + return true + case GetApiEventsProxyParamsMethodPATCH: + return true + case GetApiEventsProxyParamsMethodPOST: + return true + case GetApiEventsProxyParamsMethodPUT: + return true + default: + return false + } +} + // Defines values for GetApiEventsProxyParamsStatus. const ( GetApiEventsProxyParamsStatusFailed GetApiEventsProxyParamsStatus = "failed" GetApiEventsProxyParamsStatusSuccess GetApiEventsProxyParamsStatus = "success" ) +// Valid indicates whether the value is a known member of the GetApiEventsProxyParamsStatus enum. +func (e GetApiEventsProxyParamsStatus) Valid() bool { + switch e { + case GetApiEventsProxyParamsStatusFailed: + return true + case GetApiEventsProxyParamsStatusSuccess: + return true + default: + return false + } +} + // Defines values for PutApiIntegrationsMspTenantsIdInviteJSONBodyValue. 
const ( PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "accept" PutApiIntegrationsMspTenantsIdInviteJSONBodyValueDecline PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "decline" ) +// Valid indicates whether the value is a known member of the PutApiIntegrationsMspTenantsIdInviteJSONBodyValue enum. +func (e PutApiIntegrationsMspTenantsIdInviteJSONBodyValue) Valid() bool { + switch e { + case PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept: + return true + case PutApiIntegrationsMspTenantsIdInviteJSONBodyValueDecline: + return true + default: + return false + } +} + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // CityName Commonly used English name of the city @@ -598,7 +1398,7 @@ type BundleParameters struct { // BundleResult defines model for BundleResult. type BundleResult struct { - UploadKey *string `json:"upload_key"` + UploadKey *string `json:"upload_key,omitempty"` } // BundleWorkloadRequest defines model for BundleWorkloadRequest. @@ -1406,9 +2206,9 @@ type JobRequest struct { // JobResponse defines model for JobResponse. 
type JobResponse struct { - CompletedAt *time.Time `json:"completed_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` CreatedAt time.Time `json:"created_at"` - FailedReason *string `json:"failed_reason"` + FailedReason *string `json:"failed_reason,omitempty"` Id string `json:"id"` Status JobResponseStatus `json:"status"` TriggeredBy string `json:"triggered_by"` @@ -2419,6 +3219,12 @@ type ProxyAccessLog struct { // AuthMethodUsed Authentication method used (e.g., password, pin, oidc) AuthMethodUsed *string `json:"auth_method_used,omitempty"` + // BytesDownload Bytes downloaded (response body size) + BytesDownload int64 `json:"bytes_download"` + + // BytesUpload Bytes uploaded (request body size) + BytesUpload int64 `json:"bytes_upload"` + // CityName City name from geolocation CityName *string `json:"city_name,omitempty"` diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 97a2a4d18..2c66bb946 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.3 +// protoc v6.33.0 // source: management.proto package proto diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 77c8ea4f4..275e8be37 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.26.0 -// protoc v6.33.3 +// protoc v6.33.0 // source: proxy_service.proto package proto @@ -740,6 +740,8 @@ type AccessLog struct { AuthMechanism string `protobuf:"bytes,11,opt,name=auth_mechanism,json=authMechanism,proto3" json:"auth_mechanism,omitempty"` UserId string `protobuf:"bytes,12,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` AuthSuccess bool `protobuf:"varint,13,opt,name=auth_success,json=authSuccess,proto3" json:"auth_success,omitempty"` + BytesUpload int64 `protobuf:"varint,14,opt,name=bytes_upload,json=bytesUpload,proto3" json:"bytes_upload,omitempty"` + BytesDownload int64 `protobuf:"varint,15,opt,name=bytes_download,json=bytesDownload,proto3" json:"bytes_download,omitempty"` } func (x *AccessLog) Reset() { @@ -865,6 +867,20 @@ func (x *AccessLog) GetAuthSuccess() bool { return false } +func (x *AccessLog) GetBytesUpload() int64 { + if x != nil { + return x.BytesUpload + } + return 0 +} + +func (x *AccessLog) GetBytesDownload() int64 { + if x != nil { + return x.BytesDownload + } + return 0 +} + type AuthenticateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1698,7 +1714,7 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xa0, 0x03, 0x0a, 0x09, 0x41, 0x63, 0x63, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xea, 0x03, 0x0a, 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 
0x6d, 0x65, @@ -1724,153 +1740,158 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, - 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, 0xb6, 0x01, 0x0a, 0x13, - 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, - 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, - 0x49, 0x64, 0x12, 0x39, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x48, 0x00, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, - 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x22, 0x2d, 0x0a, 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, - 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, - 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, 0x0a, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, - 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, - 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, - 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xf3, 0x01, 0x0a, 0x17, 0x53, - 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, - 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x2d, 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, - 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, 0x73, 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, - 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, - 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 
0x10, - 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, - 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, - 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, - 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, - 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, - 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, - 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, - 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, - 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x88, 0x01, 
0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, - 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, + 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, + 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, + 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, + 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, + 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, - 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, - 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, - 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 
0x0d, 0x73, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, - 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, - 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, - 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, - 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, - 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, - 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, - 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, - 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, - 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, - 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, - 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, - 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, - 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, - 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, - 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, - 0x12, 0x19, 
0x0a, 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, - 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, - 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, - 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, - 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, - 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, - 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, - 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, + 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x39, 0x0a, 0x08, + 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x73, 0x73, + 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x08, 0x70, + 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, + 0x70, 0x69, 0x6e, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x2d, + 0x0a, 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, + 0x0a, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 
0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, + 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, + 0x14, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, + 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xf3, 0x01, 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, + 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, + 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x2d, 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, + 0x73, 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, + 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, + 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 
0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, + 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, + 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, + 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, + 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, + 0x72, 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, + 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, + 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, + 0x72, 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, + 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, + 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, + 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, + 0x0c, 0x65, 0x72, 0x72, 
0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, + 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, + 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, + 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, + 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, + 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, + 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, + 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, + 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, + 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 
0x09, 0x52, 0x06, 0x75, 0x73, 0x65, + 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, + 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, + 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, + 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, + 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, + 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, + 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, + 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, + 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, + 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, + 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, + 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, + 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, + 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, + 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, + 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, + 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, + 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, + 
0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, + 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, + 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, + 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, + 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, - 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, - 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, - 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, - 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, - 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, - 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, - 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, - 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, - 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 
0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, - 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, - 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, - 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, + 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, + 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, + 0x32, 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, + 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, - 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 
0x6e, 0x74, - 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, - 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, - 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, - 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, 0x61, 0x6c, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, - 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, - 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 
0x67, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, + 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, + 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, + 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, + 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, + 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, + 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, + 0x43, 0x55, 0x52, 
0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, + 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, + 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, + 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index be553095d..195b60f01 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -115,6 +115,8 @@ message AccessLog { string auth_mechanism = 11; string user_id = 12; bool auth_success = 13; + int64 bytes_upload = 14; + int64 bytes_download = 15; } message AuthenticateRequest { From 5585adce18cea6ba6e5673039b5d1f104006cedc Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 9 Mar 2026 19:04:04 +0100 Subject: [PATCH 07/27] [management] add activity events for domains (#5548) * add activity events for domains * fix test * update activity codes * update activity codes --- .../modules/reverseproxy/domain/domain.go | 9 +++++++ .../reverseproxy/domain/manager/manager.go | 27 ++++++++++++++----- 
management/internals/server/modules.go | 2 +- management/server/activity/codes.go | 11 ++++++++ .../testing/testing_tools/channel/channel.go | 2 +- 5 files changed, 43 insertions(+), 8 deletions(-) diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index da3432626..83fd669af 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -15,3 +15,12 @@ type Domain struct { Type Type `gorm:"-"` Validated bool } + +// EventMeta returns activity event metadata for a domain +func (d *Domain) EventMeta() map[string]any { + return map[string]any{ + "domain": d.Domain, + "target_cluster": d.TargetCluster, + "validated": d.Validated, + } +} diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 12dd051fd..8bbc98726 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -9,6 +9,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -36,16 +38,16 @@ type Manager struct { validator domain.Validator proxyManager proxyManager permissionsManager permissions.Manager + accountManager account.Manager } -func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager { +func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager { 
return Manager{ - store: store, - proxyManager: proxyMgr, - validator: domain.Validator{ - Resolver: net.DefaultResolver, - }, + store: store, + proxyManager: proxyMgr, + validator: domain.Validator{Resolver: net.DefaultResolver}, permissionsManager: permissionsManager, + accountManager: accountManager, } } @@ -136,6 +138,9 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName if err != nil { return d, fmt.Errorf("create domain in store: %w", err) } + + m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta()) + return d, nil } @@ -148,10 +153,18 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s return status.NewPermissionDeniedError() } + d, err := m.store.GetCustomDomain(ctx, accountID, domainID) + if err != nil { + return fmt.Errorf("get domain from store: %w", err) + } + if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil { // TODO: check for "no records" type error. Because that is a success condition. 
return fmt.Errorf("delete domain from store: %w", err) } + + m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta()) + return nil } @@ -218,6 +231,8 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID }).WithError(err).Error("update custom domain in store") return } + + m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta()) } else { log.WithFields(log.Fields{ "accountID": accountID, diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 2383019e2..29a8953ac 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -210,7 +210,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { - m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager()) + m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) return &m }) } diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 53cf30d4c..948d599ba 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -220,6 +220,13 @@ const ( // AccountPeerExposeDisabled indicates that a user disabled peer expose for the account AccountPeerExposeDisabled Activity = 115 + // DomainAdded indicates that a user added a custom domain + DomainAdded Activity = 118 + // DomainDeleted indicates that a user deleted a custom domain + DomainDeleted Activity = 119 + // DomainValidated indicates that a custom domain was validated + DomainValidated Activity = 120 + AccountDeleted Activity = 99999 ) @@ -364,6 +371,10 @@ var activityMap = map[Activity]Code{ AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"}, AccountPeerExposeDisabled: {"Account 
peer expose disabled", "account.setting.peer.expose.disable"}, + + DomainAdded: {"Domain added", "domain.add"}, + DomainDeleted: {"Domain deleted", "domain.delete"}, + DomainValidated: {"Domain validated", "domain.validate"}, } // StringCode returns a string code of the activity diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 5e33ad652..462013963 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -108,7 +108,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee t.Fatalf("Failed to create proxy manager: %v", err) } proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) - domainManager := manager.NewManager(store, proxyMgr, permissionsManager) + domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) From 11f891220e2ee5b1395e33a2e7a3672cd07f4bef Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:01:13 +0100 Subject: [PATCH 08/27] [management] create a shallow copy of the account when buffering (#5572) --- management/server/account_request_buffer.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index fa6c45856..e1672c2d0 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -86,7 +86,14 @@ func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, acco result := &AccountResult{Account: 
account, Err: err} for _, req := range requests { - req.ResultChan <- result + if account != nil { + // Shallow copy the account so each goroutine gets its own struct value. + // This prevents data races when callers mutate fields like Policies. + accountCopy := *account + req.ResultChan <- &AccountResult{Account: &accountCopy, Err: err} + } else { + req.ResultChan <- result + } close(req.ResultChan) } } From 7a23c57cf8157f1c62619e68c6c104202c58a96b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 11 Mar 2026 15:52:42 +0100 Subject: [PATCH 09/27] [self-hosted] Remove extra proxy domain from getting started (#5573) --- infrastructure_files/getting-started.sh | 47 ++----------------------- 1 file changed, 2 insertions(+), 45 deletions(-) diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 7fd87ee8e..70088d66a 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -182,44 +182,6 @@ read_enable_proxy() { return 0 } -read_proxy_domain() { - local suggested_proxy="proxy.${BASE_DOMAIN}" - - echo "" > /dev/stderr - echo "NOTE: The proxy domain must be different from the management domain ($NETBIRD_DOMAIN)" > /dev/stderr - echo "to avoid TLS certificate conflicts." > /dev/stderr - echo "" > /dev/stderr - echo "You also need to add a wildcard DNS record for the proxy domain," > /dev/stderr - echo "e.g. *.${suggested_proxy} pointing to the same server domain as $NETBIRD_DOMAIN with a CNAME record." > /dev/stderr - echo "" > /dev/stderr - echo -n "Enter the domain for the NetBird Proxy (e.g. ${suggested_proxy}): " > /dev/stderr - read -r READ_PROXY_DOMAIN < /dev/tty - - if [[ -z "$READ_PROXY_DOMAIN" ]]; then - echo "The proxy domain cannot be empty." 
> /dev/stderr - read_proxy_domain - return - fi - - if [[ "$READ_PROXY_DOMAIN" == "$NETBIRD_DOMAIN" ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo ${READ_PROXY_DOMAIN} | grep ${NETBIRD_DOMAIN} > /dev/null - if [[ $? -eq 0 ]]; then - echo "" > /dev/stderr - echo "WARNING: The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr - read_proxy_domain - return - fi - - echo "$READ_PROXY_DOMAIN" - return 0 -} - read_traefik_acme_email() { echo "" > /dev/stderr echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr @@ -334,7 +296,6 @@ initialize_default_values() { # NetBird Proxy configuration ENABLE_PROXY="false" - PROXY_DOMAIN="" PROXY_TOKEN="" return 0 } @@ -364,9 +325,6 @@ configure_reverse_proxy() { if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) ENABLE_PROXY=$(read_enable_proxy) - if [[ "$ENABLE_PROXY" == "true" ]]; then - PROXY_DOMAIN=$(read_proxy_domain) - fi fi # Handle external Traefik-specific prompts (option 1) @@ -813,7 +771,7 @@ NB_PROXY_MANAGEMENT_ADDRESS=http://netbird-server:80 # Allow insecure gRPC connection to management (required for internal Docker network) NB_PROXY_ALLOW_INSECURE=true # Public URL where this proxy is reachable (used for cluster registration) -NB_PROXY_DOMAIN=$PROXY_DOMAIN +NB_PROXY_DOMAIN=$NETBIRD_DOMAIN NB_PROXY_ADDRESS=:8443 NB_PROXY_TOKEN=$PROXY_TOKEN NB_PROXY_CERTIFICATE_DIRECTORY=/certs @@ -1203,8 +1161,7 @@ print_builtin_traefik_instructions() { echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge." 
echo " Point your proxy domain to this server's domain address like in the examples below:" echo "" - echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN" - echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN" + echo " *.$NETBIRD_DOMAIN CNAME $NETBIRD_DOMAIN" echo "" fi return 0 From b5489d4986e6131f663ba24e330cbb3af4b92580 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Wed, 11 Mar 2026 18:19:17 +0100 Subject: [PATCH 10/27] [management] set components network map by default and optimize memory usage (#5575) * Network map now defaults to compacted mode at startup; environment parsing issues yield clearer warnings and disabling compacted mode is logged. * **Bug Fixes** * DNS enablement and nameserver selection now correctly respect group membership, reducing incorrect DNS assignments. * **Refactor** * Internal routing and firewall rule generation streamlined for more consistent rule IDs and safer peer handling. * **Performance** * Minor memory and slice allocation improvements for peer/group processing. 
--- .../network_map/controller/controller.go | 11 +- management/server/types/account_components.go | 4 +- .../server/types/networkmap_components.go | 119 ++++++------------ 3 files changed, 51 insertions(+), 83 deletions(-) diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 121c55ac5..4b414df6f 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -87,9 +87,14 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App newNetworkMapBuilder = false } - compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err) + compactedNetworkMap := true + compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted) + parsedCompactedNmap, err := strconv.ParseBool(compactedEnv) + if err != nil && len(compactedEnv) > 0 { + log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err) + } + if err == nil && !parsedCompactedNmap { + log.WithContext(ctx).Info("disabling compacted mode") compactedNetworkMap = false } diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index 1eb25cecc..bd4244546 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -368,7 +368,7 @@ func (a *Account) getPeersGroupsPoliciesRoutes( func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) { peerInGroups := false - filteredPeerIDs := make([]string, 0, len(a.Peers)) + 
filteredPeerIDs := make([]string, 0, len(groups)) seenPeerIds := make(map[string]struct{}, len(groups)) for _, gid := range groups { @@ -378,7 +378,7 @@ func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerI } if group.IsGroupAll() || len(groups) == 1 { - filteredPeerIDs = filteredPeerIDs[:0] + filteredPeerIDs = make([]string, 0, len(group.Peers)) peerInGroups = false for _, pid := range group.Peers { peer, ok := a.Peers[pid] diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index ab6b006e6..23d84a994 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -134,7 +134,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { sourcePeers, ) - dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID) + dnsManagementStatus := c.getPeerDNSManagementStatusFromGroups(peerGroups) dnsUpdate := nbdns.Config{ ServiceEnable: dnsManagementStatus, } @@ -152,7 +152,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { customZones = append(customZones, c.AccountZones...) 
dnsUpdate.CustomZones = customZones - dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID) + dnsUpdate.NameServerGroups = c.getPeerNSGroupsFromGroups(targetPeerID, peerGroups) } return &NetworkMap{ @@ -278,6 +278,16 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peers := make([]*nbpeer.Peer, 0) return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + + protocolStr := string(protocol) + actionStr := string(rule.Action) + dirStr := strconv.Itoa(direction) + portsJoined := strings.Join(rule.Ports, ",") + for _, peer := range groupPeers { if peer == nil { continue @@ -288,21 +298,18 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peersExists[peer.ID] = struct{}{} } - protocol := rule.Protocol - if protocol == PolicyRuleProtocolNetbirdSSH { - protocol = PolicyRuleProtocolTCP - } + peerIP := net.IP(peer.IP).String() fr := FirewallRule{ PolicyID: rule.ID, - PeerIP: net.IP(peer.IP).String(), + PeerIP: peerIP, Direction: direction, - Action: string(rule.Action), - Protocol: string(protocol), + Action: actionStr, + Protocol: protocolStr, } - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") + ruleID := rule.ID + peerIP + dirStr + + protocolStr + actionStr + portsJoined if _, ok := rulesExists[ruleID]; ok { continue } @@ -313,13 +320,7 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( continue } - rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{ - ID: rule.ID, - Ports: rule.Ports, - PortRanges: rule.PortRanges, - Protocol: rule.Protocol, - Action: rule.Action, - }, targetPeer)...) + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) 
} }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -395,7 +396,7 @@ func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID str } func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) { - var peersToConnect []*nbpeer.Peer + peersToConnect := make([]*nbpeer.Peer, 0, len(aclPeers)) var expiredPeers []*nbpeer.Peer for _, p := range aclPeers { @@ -410,35 +411,35 @@ func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.P return peersToConnect, expiredPeers } -func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := c.GetPeerGroups(peerID) - enabled := true +func (c *NetworkMapComponents) getPeerDNSManagementStatusFromGroups(peerGroups map[string]struct{}) bool { for _, groupID := range c.DNSSettings.DisabledManagementGroups { if _, found := peerGroups[groupID]; found { - enabled = false - break + return false } } - return enabled + return true } -func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup { - groupList := c.GetPeerGroups(peerID) - +func (c *NetworkMapComponents) getPeerNSGroupsFromGroups(peerID string, groupList map[string]struct{}) []*nbdns.NameServerGroup { var peerNSGroups []*nbdns.NameServerGroup + targetPeerInfo := c.GetPeerInfo(peerID) + if targetPeerInfo == nil { + return peerNSGroups + } + + peerIPStr := targetPeerInfo.IP.String() + for _, nsGroup := range c.NameServerGroups { if !nsGroup.Enabled { continue } for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - targetPeerInfo := c.GetPeerInfo(peerID) - if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) { + if _, found := groupList[gID]; found { + if !c.peerIsNameserver(peerIPStr, nsGroup) { peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break } + break } } } @@ -446,9 +447,9 @@ func (c *NetworkMapComponents) getPeerNSGroups(peerID string) 
[]*nbdns.NameServe return peerNSGroups } -func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { +func (c *NetworkMapComponents) peerIsNameserver(peerIPStr string, nsGroup *nbdns.NameServerGroup) bool { for _, ns := range nsGroup.NameServers { - if peerInfo.IP.String() == ns.IP.String() { + if peerIPStr == ns.IP.String() { return true } } @@ -489,14 +490,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute } seenRoute[r.ID] = struct{}{} - routeObj := c.copyRoute(r) - routeObj.Peer = peerInfo.Key + r.Peer = peerInfo.Key if r.Enabled { - enabledRoutes = append(enabledRoutes, routeObj) + enabledRoutes = append(enabledRoutes, r) return } - disabledRoutes = append(disabledRoutes, routeObj) + disabledRoutes = append(disabledRoutes, r) } for _, r := range c.Routes { @@ -510,7 +510,7 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute continue } - newPeerRoute := c.copyRoute(r) + newPeerRoute := r.Copy() newPeerRoute.Peer = id newPeerRoute.PeerGroups = nil newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) @@ -519,50 +519,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute } } if r.Peer == peerID { - takeRoute(c.copyRoute(r)) + takeRoute(r.Copy()) } } return enabledRoutes, disabledRoutes } -func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route { - var groups, accessControlGroups, peerGroups []string - var domains domain.List - - if r.Groups != nil { - groups = append([]string{}, r.Groups...) - } - if r.AccessControlGroups != nil { - accessControlGroups = append([]string{}, r.AccessControlGroups...) - } - if r.PeerGroups != nil { - peerGroups = append([]string{}, r.PeerGroups...) - } - if r.Domains != nil { - domains = append(domain.List{}, r.Domains...) 
- } - - return &route.Route{ - ID: r.ID, - AccountID: r.AccountID, - Network: r.Network, - NetworkType: r.NetworkType, - Description: r.Description, - Peer: r.Peer, - PeerID: r.PeerID, - Metric: r.Metric, - Masquerade: r.Masquerade, - NetID: r.NetID, - Enabled: r.Enabled, - Groups: groups, - AccessControlGroups: accessControlGroups, - PeerGroups: peerGroups, - Domains: domains, - KeepRoute: r.KeepRoute, - SkipAutoApply: r.SkipAutoApply, - } -} func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { var filteredRoutes []*route.Route From d3d6a327e002245ff6cb12350c7c0e2253ee092c Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 11 Mar 2026 19:18:37 +0100 Subject: [PATCH 11/27] [proxy] read cert from disk if available instead of cert manager (#5574) * **New Features** * Asynchronous certificate prefetch that races live issuance with periodic on-disk cache checks to surface certificates faster. * Centralized recording and notification when certificates become available. * New on-disk certificate reading and validation to allow immediate use of cached certs. * **Bug Fixes & Performance** * Optimized retrieval by polling disk while fetching in background to reduce latency. * Added cancellation and timeout handling to fail stalled certificate operations reliably. 
--- proxy/internal/acme/manager.go | 117 ++++++++++++++++++++++++++++----- 1 file changed, 101 insertions(+), 16 deletions(-) diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index ebc15314b..b1e532e83 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -7,9 +7,12 @@ import ( "encoding/asn1" "encoding/base64" "encoding/binary" + "encoding/pem" "fmt" + "math/rand/v2" "net" "slices" + "strings" "sync" "time" @@ -137,7 +140,12 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { // It acquires a distributed lock to prevent multiple replicas from issuing // duplicate ACME requests. The second replica will block until the first // finishes, then find the certificate in the cache. +// ACME and periodic disk reads race; whichever produces a valid certificate +// first wins. This handles cases where locking is unreliable and another +// replica already wrote the cert to the shared cache. func (mgr *Manager) prefetchCertificate(d domain.Domain) { + time.Sleep(time.Duration(rand.IntN(200)) * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -153,26 +161,105 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { defer unlock() } - hello := &tls.ClientHelloInfo{ - ServerName: name, - Conn: &dummyConn{ctx: ctx}, - } - - start := time.Now() - cert, err := mgr.GetCertificate(hello) - elapsed := time.Since(start) - if err != nil { - mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), err) - mgr.setDomainState(d, domainFailed, err.Error()) + if cert, err := mgr.readCertFromDisk(ctx, name); err == nil { + mgr.logger.Infof("certificate for domain %q already on disk, skipping ACME", name) + mgr.recordAndNotify(ctx, d, name, cert, 0) return } - if mgr.metrics != nil { + // Run ACME in a goroutine so we can race it against periodic disk reads. 
+ // autocert uses its own internal context and cannot be cancelled externally. + type acmeResult struct { + cert *tls.Certificate + err error + } + acmeCh := make(chan acmeResult, 1) + hello := &tls.ClientHelloInfo{ServerName: name, Conn: &dummyConn{ctx: ctx}} + go func() { + cert, err := mgr.GetCertificate(hello) + acmeCh <- acmeResult{cert, err} + }() + + start := time.Now() + diskTicker := time.NewTicker(5 * time.Second) + defer diskTicker.Stop() + + for { + select { + case res := <-acmeCh: + elapsed := time.Since(start) + if res.err != nil { + mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), res.err) + mgr.setDomainState(d, domainFailed, res.err.Error()) + return + } + mgr.recordAndNotify(ctx, d, name, res.cert, elapsed) + return + + case <-diskTicker.C: + cert, err := mgr.readCertFromDisk(context.Background(), name) + if err != nil { + continue + } + mgr.logger.Infof("certificate for domain %q appeared on disk after %s", name, time.Since(start).Round(time.Millisecond)) + // Drain the ACME goroutine before marking ready — autocert holds + // an internal write lock on certState while ACME is in flight. + go func() { + select { + case <-acmeCh: + default: + } + mgr.recordAndNotify(context.Background(), d, name, cert, 0) + }() + return + + case <-ctx.Done(): + mgr.logger.Warnf("prefetch certificate for domain %q timed out", name) + mgr.setDomainState(d, domainFailed, ctx.Err().Error()) + return + } + } +} + +// readCertFromDisk reads and parses a certificate directly from the autocert +// DirCache, bypassing autocert's internal certState mutex. Safe to call +// concurrently with an in-flight ACME request for the same domain. 
+func (mgr *Manager) readCertFromDisk(ctx context.Context, name string) (*tls.Certificate, error) { + if mgr.Cache == nil { + return nil, fmt.Errorf("no cache configured") + } + data, err := mgr.Cache.Get(ctx, name) + if err != nil { + return nil, err + } + privBlock, certsPEM := pem.Decode(data) + if privBlock == nil || !strings.Contains(privBlock.Type, "PRIVATE") { + return nil, fmt.Errorf("no private key in cache for %q", name) + } + cert, err := tls.X509KeyPair(certsPEM, pem.EncodeToMemory(privBlock)) + if err != nil { + return nil, fmt.Errorf("parse cached certificate for %q: %w", name, err) + } + if len(cert.Certificate) > 0 { + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("parse leaf for %q: %w", name, err) + } + if time.Now().After(leaf.NotAfter) { + return nil, fmt.Errorf("cached certificate for %q expired at %s", name, leaf.NotAfter) + } + cert.Leaf = leaf + } + return &cert, nil +} + +// recordAndNotify records metrics, marks the domain ready, logs cert details, +// and notifies the cert notifier. 
+func (mgr *Manager) recordAndNotify(ctx context.Context, d domain.Domain, name string, cert *tls.Certificate, elapsed time.Duration) { + if elapsed > 0 && mgr.metrics != nil { mgr.metrics.RecordCertificateIssuance(elapsed) } - mgr.setDomainState(d, domainReady, "") - now := time.Now() if cert != nil && cert.Leaf != nil { leaf := cert.Leaf @@ -188,11 +275,9 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { } else { mgr.logger.Infof("certificate for domain %q ready in %s", name, elapsed.Round(time.Millisecond)) } - mgr.mu.RLock() info := mgr.domains[d] mgr.mu.RUnlock() - if info != nil && mgr.certNotifier != nil { if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil { mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err) From 8f389fef19cf31fb74e219e03ca392b43cbe96d4 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:57:36 +0100 Subject: [PATCH 12/27] [management] fix some concurrency potential issues (#5584) --- management/server/account_request_buffer.go | 17 ++++++++++---- .../server/http/middleware/bypass/bypass.go | 23 +++++++++++++------ management/server/job/channel.go | 8 ++++++- management/server/types/network.go | 2 ++ 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index e1672c2d0..ac53a9fa8 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -63,11 +63,20 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID) startTime := time.Now() - ac.getAccountRequestCh <- req - result := <-req.ResultChan - log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime)) - return result.Account, result.Err + select { + case <-ctx.Done(): 
+ return nil, ctx.Err() + case ac.getAccountRequestCh <- req: + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-req.ResultChan: + log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime)) + return result.Account, result.Err + } } func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) { diff --git a/management/server/http/middleware/bypass/bypass.go b/management/server/http/middleware/bypass/bypass.go index 9447704cb..ddece7152 100644 --- a/management/server/http/middleware/bypass/bypass.go +++ b/management/server/http/middleware/bypass/bypass.go @@ -51,19 +51,28 @@ func GetList() []string { // This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication. func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool { byPassMutex.RLock() - defer byPassMutex.RUnlock() - + var matched bool for bypassPath := range bypassPaths { - matched, err := path.Match(bypassPath, requestPath) + m, err := path.Match(bypassPath, requestPath) if err != nil { - log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) + list := make([]string, 0, len(bypassPaths)) + for k := range bypassPaths { + list = append(list, k) + } + log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %v: %v", bypassPath, requestPath, list, err) continue } - if matched { - h.ServeHTTP(w, r) - return true + if m { + matched = true + break } } + byPassMutex.RUnlock() + + if matched { + h.ServeHTTP(w, r) + return true + } return false } diff --git a/management/server/job/channel.go b/management/server/job/channel.go index c4dc98a68..c4454c4c9 100644 --- a/management/server/job/channel.go +++ b/management/server/job/channel.go @@ -28,7 +28,13 @@ func NewChannel() *Channel { return jc } -func (jc *Channel) AddEvent(ctx 
context.Context, responseWait time.Duration, event *Event) error { +func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) (err error) { + defer func() { + if r := recover(); r != nil { + err = ErrJobChannelClosed + } + }() + select { case <-ctx.Done(): return ctx.Err() diff --git a/management/server/types/network.go b/management/server/types/network.go index d3708d80a..0d13de10f 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -152,6 +152,8 @@ func (n *Network) CurrentSerial() uint64 { } func (n *Network) Copy() *Network { + n.Mu.Lock() + defer n.Mu.Unlock() return &Network{ Identifier: n.Identifier, Net: n.Net, From c545689448b0630502242a5935aeff2f73d52a87 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:00:28 +0100 Subject: [PATCH 13/27] [proxy] Wildcard certificate support (#5583) --- proxy/cmd/proxy/cmd/root.go | 3 + proxy/internal/acme/manager.go | 227 +++++++++++++++++++++++++--- proxy/internal/acme/manager_test.go | 206 ++++++++++++++++++++++++- proxy/internal/certwatch/watcher.go | 7 + proxy/server.go | 34 ++++- 5 files changed, 455 insertions(+), 22 deletions(-) diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 50aa38b29..61ed5871e 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -53,6 +53,7 @@ var ( certFile string certKeyFile string certLockMethod string + wildcardCertDir string wgPort int proxyProtocol bool preSharedKey string @@ -87,6 +88,7 @@ func init() { rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory") rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory") 
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") + rootCmd.Flags().StringVar(&wildcardCertDir, "wildcard-cert-dir", envStringOrDefault("NB_PROXY_WILDCARD_CERT_DIR", ""), "Directory containing wildcard certificate pairs (.crt/.key). Wildcard patterns are extracted from SANs automatically") rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") @@ -162,6 +164,7 @@ func runServer(cmd *cobra.Command, args []string) error { ForwardedProto: forwardedProto, TrustedProxies: parsedTrustedProxies, CertLockMethod: nbacme.CertLockMethod(certLockMethod), + WildcardCertDir: wildcardCertDir, WireguardPort: wgPort, ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index b1e532e83..395da7d88 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -11,6 +11,8 @@ import ( "fmt" "math/rand/v2" "net" + "os" + "path/filepath" "slices" "strings" "sync" @@ -20,6 +22,7 @@ import ( "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" + "github.com/netbirdio/netbird/proxy/internal/certwatch" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -49,6 +52,34 @@ type metricsRecorder interface { RecordCertificateIssuance(duration time.Duration) } +// wildcardEntry maps a domain suffix (e.g. 
".example.com") to a certwatch +// watcher that hot-reloads the corresponding wildcard certificate from disk. +type wildcardEntry struct { + suffix string // e.g. ".example.com" + pattern string // e.g. "*.example.com" + watcher *certwatch.Watcher +} + +// ManagerConfig holds the configuration values for the ACME certificate manager. +type ManagerConfig struct { + // CertDir is the directory used for caching ACME certificates. + CertDir string + // ACMEURL is the ACME directory URL (e.g. Let's Encrypt). + ACMEURL string + // EABKID and EABHMACKey are optional External Account Binding credentials + // required by some CAs (e.g. ZeroSSL). EABHMACKey is the base64 + // URL-encoded string provided by the CA. + EABKID string + EABHMACKey string + // LockMethod controls the cross-replica coordination strategy. + LockMethod CertLockMethod + // WildcardDir is an optional path to a directory containing wildcard + // certificate pairs (.crt / .key). Wildcard patterns are + // extracted from the certificates' SAN lists. Domains matching a + // wildcard are served from disk; all others go through ACME. + WildcardDir string +} + // Manager wraps autocert.Manager with domain tracking and cross-replica // coordination via a pluggable locking strategy. The locker prevents // duplicate ACME requests when multiple replicas share a certificate cache. @@ -60,54 +91,182 @@ type Manager struct { mu sync.RWMutex domains map[domain.Domain]*domainInfo + // wildcards holds all loaded wildcard certificates, keyed by suffix. + wildcards []wildcardEntry + certNotifier certificateNotifier logger *log.Logger metrics metricsRecorder } -// NewManager creates a new ACME certificate manager. The certDir is used -// for caching certificates. The lockMethod controls cross-replica -// coordination strategy (see CertLockMethod constants). -// eabKID and eabHMACKey are optional External Account Binding credentials -// required for some CAs like ZeroSSL. 
The eabHMACKey should be the base64 -// URL-encoded string provided by the CA. -func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod, metrics metricsRecorder) *Manager { +// NewManager creates a new ACME certificate manager. +func NewManager(cfg ManagerConfig, notifier certificateNotifier, logger *log.Logger, metrics metricsRecorder) (*Manager, error) { if logger == nil { logger = log.StandardLogger() } mgr := &Manager{ - certDir: certDir, - locker: newCertLocker(lockMethod, certDir, logger), + certDir: cfg.CertDir, + locker: newCertLocker(cfg.LockMethod, cfg.CertDir, logger), domains: make(map[domain.Domain]*domainInfo), certNotifier: notifier, logger: logger, metrics: metrics, } + if cfg.WildcardDir != "" { + entries, err := loadWildcardDir(cfg.WildcardDir, logger) + if err != nil { + return nil, fmt.Errorf("load wildcard certificates from %q: %w", cfg.WildcardDir, err) + } + mgr.wildcards = entries + } + var eab *acme.ExternalAccountBinding - if eabKID != "" && eabHMACKey != "" { - decodedKey, err := base64.RawURLEncoding.DecodeString(eabHMACKey) + if cfg.EABKID != "" && cfg.EABHMACKey != "" { + decodedKey, err := base64.RawURLEncoding.DecodeString(cfg.EABHMACKey) if err != nil { logger.Errorf("failed to decode EAB HMAC key: %v", err) } else { eab = &acme.ExternalAccountBinding{ - KID: eabKID, + KID: cfg.EABKID, Key: decodedKey, } - logger.Infof("configured External Account Binding with KID: %s", eabKID) + logger.Infof("configured External Account Binding with KID: %s", cfg.EABKID) } } mgr.Manager = &autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: mgr.hostPolicy, - Cache: autocert.DirCache(certDir), + Cache: autocert.DirCache(cfg.CertDir), ExternalAccountBinding: eab, Client: &acme.Client{ - DirectoryURL: acmeURL, + DirectoryURL: cfg.ACMEURL, }, } - return mgr + return mgr, nil +} + +// WatchWildcards starts watching all wildcard certificate files for changes. 
+// It blocks until ctx is cancelled. It is a no-op if no wildcards are loaded. +func (mgr *Manager) WatchWildcards(ctx context.Context) { + if len(mgr.wildcards) == 0 { + return + } + seen := make(map[*certwatch.Watcher]struct{}) + var wg sync.WaitGroup + for i := range mgr.wildcards { + w := mgr.wildcards[i].watcher + if _, ok := seen[w]; ok { + continue + } + seen[w] = struct{}{} + wg.Add(1) + go func() { + defer wg.Done() + w.Watch(ctx) + }() + } + wg.Wait() +} + +// loadWildcardDir scans dir for .crt files, pairs each with a matching .key +// file, loads them, and extracts wildcard SANs (*.example.com) to build +// the suffix lookup entries. +func loadWildcardDir(dir string, logger *log.Logger) ([]wildcardEntry, error) { + crtFiles, err := filepath.Glob(filepath.Join(dir, "*.crt")) + if err != nil { + return nil, fmt.Errorf("glob certificate files: %w", err) + } + + if len(crtFiles) == 0 { + return nil, fmt.Errorf("no .crt files found in %s", dir) + } + + var entries []wildcardEntry + + for _, crtPath := range crtFiles { + base := strings.TrimSuffix(filepath.Base(crtPath), ".crt") + keyPath := filepath.Join(dir, base+".key") + if _, err := os.Stat(keyPath); err != nil { + logger.Warnf("skipping %s: no matching key file %s", crtPath, keyPath) + continue + } + + watcher, err := certwatch.NewWatcher(crtPath, keyPath, logger) + if err != nil { + logger.Warnf("skipping %s: %v", crtPath, err) + continue + } + + leaf := watcher.Leaf() + if leaf == nil { + logger.Warnf("skipping %s: no parsed leaf certificate", crtPath) + continue + } + + for _, san := range leaf.DNSNames { + suffix, ok := parseWildcard(san) + if !ok { + continue + } + entries = append(entries, wildcardEntry{ + suffix: suffix, + pattern: san, + watcher: watcher, + }) + logger.Infof("wildcard certificate loaded: %s (from %s)", san, filepath.Base(crtPath)) + } + } + + if len(entries) == 0 { + return nil, fmt.Errorf("no wildcard SANs (*.example.com) found in certificates in %s", dir) + } + + return 
entries, nil +} + +// parseWildcard validates a wildcard domain pattern like "*.example.com" +// and returns the suffix ".example.com" for matching. +func parseWildcard(pattern string) (suffix string, ok bool) { + if !strings.HasPrefix(pattern, "*.") { + return "", false + } + parent := pattern[1:] // ".example.com" + if strings.Count(parent, ".") < 1 { + return "", false + } + return strings.ToLower(parent), true +} + +// findWildcardEntry returns the wildcard entry that covers host, or nil. +func (mgr *Manager) findWildcardEntry(host string) *wildcardEntry { + if len(mgr.wildcards) == 0 { + return nil + } + host = strings.ToLower(host) + for i := range mgr.wildcards { + e := &mgr.wildcards[i] + if !strings.HasSuffix(host, e.suffix) { + continue + } + // Single-level match: prefix before suffix must have no dots. + prefix := strings.TrimSuffix(host, e.suffix) + if len(prefix) > 0 && !strings.Contains(prefix, ".") { + return e + } + } + return nil +} + +// WildcardPatterns returns the wildcard patterns that are currently loaded. +func (mgr *Manager) WildcardPatterns() []string { + patterns := make([]string, len(mgr.wildcards)) + for i, e := range mgr.wildcards { + patterns[i] = e.pattern + } + slices.Sort(patterns) + return patterns } func (mgr *Manager) hostPolicy(_ context.Context, host string) error { @@ -123,8 +282,39 @@ func (mgr *Manager) hostPolicy(_ context.Context, host string) error { return nil } -// AddDomain registers a domain for ACME certificate prefetching. -func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { +// GetCertificate returns the TLS certificate for the given ClientHello. +// If the requested domain matches a loaded wildcard, the static wildcard +// certificate is returned. Otherwise, the ACME autocert manager handles +// the request. 
+func (mgr *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if e := mgr.findWildcardEntry(hello.ServerName); e != nil { + return e.watcher.GetCertificate(hello) + } + return mgr.Manager.GetCertificate(hello) +} + +// AddDomain registers a domain for certificate management. Domains that +// match a loaded wildcard are marked ready immediately (they use the +// static wildcard certificate) and the method returns true. All other +// domains go through ACME prefetch and the method returns false. +// +// When AddDomain returns true the caller is responsible for sending any +// certificate-ready notifications after the surrounding operation (e.g. +// mapping update) has committed successfully. +func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) (wildcardHit bool) { + name := d.PunycodeString() + if e := mgr.findWildcardEntry(name); e != nil { + mgr.mu.Lock() + mgr.domains[d] = &domainInfo{ + accountID: accountID, + serviceID: serviceID, + state: domainReady, + } + mgr.mu.Unlock() + mgr.logger.Debugf("domain %q matches wildcard %q, using static certificate", name, e.pattern) + return true + } + mgr.mu.Lock() mgr.domains[d] = &domainInfo{ accountID: accountID, @@ -134,6 +324,7 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) { mgr.mu.Unlock() go mgr.prefetchCertificate(d) + return false } // prefetchCertificate proactively triggers certificate generation for a domain. 
diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go index 30a27c612..9a3ed9efd 100644 --- a/proxy/internal/acme/manager_test.go +++ b/proxy/internal/acme/manager_test.go @@ -2,6 +2,16 @@ package acme import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" "testing" "time" @@ -10,7 +20,8 @@ import ( ) func TestHostPolicy(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil) + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) mgr.AddDomain("example.com", "acc1", "rp1") // Wait for the background prefetch goroutine to finish so the temp dir @@ -70,7 +81,8 @@ func TestHostPolicy(t *testing.T) { } func TestDomainStates(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil) + mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) + require.NoError(t, err) assert.Equal(t, 0, mgr.PendingCerts(), "initially zero") assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains") @@ -100,3 +112,193 @@ func TestDomainStates(t *testing.T) { assert.Contains(t, failed, "b.example.com") assert.Empty(t, mgr.ReadyDomains()) } + +func TestParseWildcard(t *testing.T) { + tests := []struct { + pattern string + wantSuffix string + wantOK bool + }{ + {"*.example.com", ".example.com", true}, + {"*.foo.example.com", ".foo.example.com", true}, + {"*.COM", ".com", true}, // single-label TLD + {"example.com", "", false}, // no wildcard prefix + {"*example.com", "", false}, // missing dot + {"**.example.com", "", false}, // double star + {"", "", false}, + } + + for _, tc := range tests { + t.Run(tc.pattern, func(t *testing.T) { + suffix, ok := 
parseWildcard(tc.pattern)
+			assert.Equal(t, tc.wantOK, ok)
+			if ok {
+				assert.Equal(t, tc.wantSuffix, suffix)
+			}
+		})
+	}
+}
+
+func TestMatchesWildcard(t *testing.T) {
+	wcDir := t.TempDir()
+	generateSelfSignedCert(t, wcDir, "example", "*.example.com")
+
+	acmeDir := t.TempDir()
+	mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil)
+	require.NoError(t, err)
+
+	tests := []struct {
+		host  string
+		match bool
+	}{
+		{"foo.example.com", true},
+		{"bar.example.com", true},
+		{"FOO.Example.COM", true},      // case insensitive
+		{"example.com", false},         // bare parent
+		{"sub.foo.example.com", false}, // multi-level
+		{"notexample.com", false},
+		{"", false},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.host, func(t *testing.T) {
+			assert.Equal(t, tc.match, mgr.findWildcardEntry(tc.host) != nil)
+		})
+	}
+}
+
+// generateSelfSignedCert creates a temporary self-signed certificate and key
+// for testing purposes. The baseName controls the output filenames:
+// <baseName>.crt and <baseName>.key.
+func generateSelfSignedCert(t *testing.T, dir, baseName string, dnsNames ...string) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: dnsNames[0]}, + DNSNames: dnsNames, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certFile, err := os.Create(filepath.Join(dir, baseName+".crt")) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + keyDER, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + keyFile, err := os.Create(filepath.Join(dir, baseName+".key")) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + require.NoError(t, keyFile.Close()) +} + +func TestWildcardAddDomainSkipsACME(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + // Add a wildcard-matching domain — should be immediately ready. + mgr.AddDomain("foo.example.com", "acc1", "svc1") + assert.Equal(t, 0, mgr.PendingCerts(), "wildcard domain should not be pending") + assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) + + // Add a non-wildcard domain — should go through ACME (pending then failed). + mgr.AddDomain("other.net", "acc2", "svc2") + assert.Equal(t, 2, mgr.TotalDomains()) + + // Wait for the ACME prefetch to fail. 
+ assert.Eventually(t, func() bool { + return mgr.PendingCerts() == 0 + }, 30*time.Second, 100*time.Millisecond) + + assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) + assert.Contains(t, mgr.FailedDomains(), "other.net") +} + +func TestWildcardGetCertificate(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + mgr.AddDomain("foo.example.com", "acc1", "svc1") + + // GetCertificate for a wildcard-matching domain should return the static cert. + cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + require.NotNil(t, cert) + assert.Contains(t, cert.Leaf.DNSNames, "*.example.com") +} + +func TestMultipleWildcards(t *testing.T) { + wcDir := t.TempDir() + generateSelfSignedCert(t, wcDir, "example", "*.example.com") + generateSelfSignedCert(t, wcDir, "other", "*.other.org") + + acmeDir := t.TempDir() + mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.NoError(t, err) + + assert.ElementsMatch(t, []string{"*.example.com", "*.other.org"}, mgr.WildcardPatterns()) + + // Both wildcards should resolve. + mgr.AddDomain("foo.example.com", "acc1", "svc1") + mgr.AddDomain("bar.other.org", "acc2", "svc2") + + assert.Equal(t, 0, mgr.PendingCerts()) + assert.ElementsMatch(t, []string{"foo.example.com", "bar.other.org"}, mgr.ReadyDomains()) + + // GetCertificate routes to the correct cert. 
+ cert1, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + assert.Contains(t, cert1.Leaf.DNSNames, "*.example.com") + + cert2, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "bar.other.org"}) + require.NoError(t, err) + assert.Contains(t, cert2.Leaf.DNSNames, "*.other.org") + + // Non-matching domain falls through to ACME. + mgr.AddDomain("custom.net", "acc3", "svc3") + assert.Eventually(t, func() bool { + return mgr.PendingCerts() == 0 + }, 30*time.Second, 100*time.Millisecond) + assert.Contains(t, mgr.FailedDomains(), "custom.net") +} + +func TestWildcardDirEmpty(t *testing.T) { + wcDir := t.TempDir() + // Empty directory — no .crt files. + _, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no .crt files found") +} + +func TestWildcardDirNonWildcardCert(t *testing.T) { + wcDir := t.TempDir() + // Certificate without a wildcard SAN. + generateSelfSignedCert(t, wcDir, "plain", "plain.example.com") + + _, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no wildcard SANs") +} + +func TestNoWildcardDir(t *testing.T) { + // Empty string means no wildcard dir — pure ACME mode. 
+	mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil)
+	require.NoError(t, err)
+	assert.Empty(t, mgr.WildcardPatterns())
+}
diff --git a/proxy/internal/certwatch/watcher.go b/proxy/internal/certwatch/watcher.go
index 78ad1ab7c..6366a53c6 100644
--- a/proxy/internal/certwatch/watcher.go
+++ b/proxy/internal/certwatch/watcher.go
@@ -67,6 +67,13 @@ func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, erro
 	return w.cert, nil
 }
 
+// Leaf returns the parsed leaf certificate, or nil if not yet loaded.
+func (w *Watcher) Leaf() *x509.Certificate {
+	w.mu.RLock()
+	defer w.mu.RUnlock()
+	return w.leaf
+}
+
 // Watch starts watching for certificate file changes. It blocks until
 // ctx is cancelled. It uses fsnotify for immediate detection and falls
 // back to polling if fsnotify is unavailable (e.g. on NFS).
diff --git a/proxy/server.go b/proxy/server.go
index 123b14648..62e8368e6 100644
--- a/proxy/server.go
+++ b/proxy/server.go
@@ -97,6 +97,11 @@ type Server struct {
 	// CertLockMethod controls how ACME certificate locks are coordinated
 	// across replicas. Default: CertLockAuto (detect environment).
 	CertLockMethod acme.CertLockMethod
+	// WildcardCertDir is an optional directory containing wildcard certificate
+	// pairs (<name>.crt / <name>.key). Wildcard patterns are extracted from
+	// the certificates' SAN lists. Matching domains use these static certs
+	// instead of ACME.
+	WildcardCertDir string
 	// DebugEndpointEnabled enables the debug HTTP endpoint.
DebugEndpointEnabled bool @@ -437,7 +442,20 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { "acme_server": s.ACMEDirectory, "challenge_type": s.ACMEChallengeType, }).Debug("ACME certificates enabled, configuring certificate manager") - s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod, s.meter) + var err error + s.acme, err = acme.NewManager(acme.ManagerConfig{ + CertDir: s.CertificateDirectory, + ACMEURL: s.ACMEDirectory, + EABKID: s.ACMEEABKID, + EABHMACKey: s.ACMEEABHMACKey, + LockMethod: s.CertLockMethod, + WildcardDir: s.WildcardCertDir, + }, s, s.Logger, s.meter) + if err != nil { + return nil, fmt.Errorf("create ACME manager: %w", err) + } + + go s.acme.WatchWildcards(ctx) if s.ACMEChallengeType == "http-01" { s.http = &http.Server{ @@ -453,6 +471,10 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { } tlsConfig = s.acme.TLSConfig() + // autocert.Manager.TLSConfig() wires its own GetCertificate, which + // bypasses our override that checks wildcards first. + tlsConfig.GetCertificate = s.acme.GetCertificate + // ServerName needs to be set to allow for ACME to work correctly // when using CNAME URLs to access the proxy. 
tlsConfig.ServerName = s.ProxyURL @@ -675,8 +697,9 @@ func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) er if err := s.netbird.AddPeer(ctx, accountID, d, authToken, serviceID); err != nil { return fmt.Errorf("create peer for domain %q: %w", d, err) } + var wildcardHit bool if s.acme != nil { - s.acme.AddDomain(d, string(accountID), serviceID) + wildcardHit = s.acme.AddDomain(d, string(accountID), serviceID) } // Pass the mapping through to the update function to avoid duplicating the @@ -686,6 +709,13 @@ func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) er s.removeMapping(ctx, mapping) return fmt.Errorf("update mapping for domain %q: %w", d, err) } + + if wildcardHit { + if err := s.NotifyCertificateIssued(ctx, string(accountID), serviceID, string(d)); err != nil { + s.Logger.Warnf("notify certificate ready for domain %q: %v", d, err) + } + } + return nil } From e50e124e70b92510fb69c4928fc81ab9bc9470c3 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:12:26 +0100 Subject: [PATCH 14/27] [proxy] Fix domain switching update (#5585) --- .../internals/modules/reverseproxy/service/manager/manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 56a1fc98a..cae3d3bda 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -392,7 +392,7 @@ func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID oidcCfg := m.proxyController.GetOIDCValidationConfig() switch { - case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster: + case updateInfo.domainChanged || updateInfo.oldCluster != s.ProxyCluster: m.proxyController.SendServiceUpdateToCluster(ctx, 
accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) case !s.Enabled && updateInfo.serviceEnabledChanged: From 967c6f3cd34e36a2f4e0199ea5691a6a6bf8e24d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 13 Mar 2026 09:47:00 +0100 Subject: [PATCH 15/27] [misc] Add GPG signing key support for rpm packages (#5581) * [misc] Add GPG signing key support for deb and rpm packages * [misc] Improve GPG key management for deb and rpm signing * [misc] Extract GPG key import logic into a reusable script * [misc] Add key fingerprint extraction and targeted export for GPG keys * [misc] Remove passphrase from GPG keys before exporting * [misc] Simplify GPG key management by removing import script * [misc] Bump GoReleaser version to v2.14.3 in release workflow * [misc] Replace GPG passphrase variables with NFPM-prefixed alternatives in workflows and configs * [misc] Update naming conventions for package IDs and passphrase variables in workflows and configs * [misc] Standardize NFPM variable naming in release workflow * [misc] Adjust NFPM variable names for consistency in release workflow * [misc] Remove Debian signing GPG key usage in workflows and configs --- .github/workflows/release.yml | 52 ++++++++++++++++++++++++++++++++++- .goreleaser.yaml | 13 +++++---- .goreleaser_ui.yaml | 11 +++++--- 3 files changed, 65 insertions(+), 11 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d1f085b47..7ac5103d9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,7 @@ on: env: SIGN_PIPE_VER: "v0.1.1" - GORELEASER_VER: "v2.3.2" + GORELEASER_VER: "v2.14.3" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -169,6 +169,13 @@ jobs: - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf 
gcc-aarch64-linux-gnu + - name: Decode GPG signing key + env: + GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} + run: | + echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc + echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV + - name: Install goversioninfo run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e - name: Generate windows syso amd64 @@ -186,6 +193,24 @@ jobs: HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} + GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} + NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} + - name: Verify RPM signatures + run: | + docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' + dnf install -y -q rpm-sign curl >/dev/null 2>&1 + curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key + rpm --import /tmp/rpm-pub.key + echo "=== Verifying RPM signatures ===" + for rpm_file in /dist/*amd64*.rpm; do + [ -f "$rpm_file" ] || continue + echo "--- $(basename $rpm_file) ---" + rpm -K "$rpm_file" + done + ' + - name: Clean up GPG key + if: always() + run: rm -f /tmp/gpg-rpm-signing-key.asc - name: Tag and push PR images (amd64 only) if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository run: | @@ -265,6 +290,13 @@ jobs: - name: Install dependencies run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 + - name: Decode GPG signing key + env: + GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} + run: | + echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc + echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV + - name: Install LLVM-MinGW for ARM64 cross-compilation run: | cd /tmp @@ -289,6 +321,24 @@ jobs: HOMEBREW_TAP_GITHUB_TOKEN: ${{ 
secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} + GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} + NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} + - name: Verify RPM signatures + run: | + docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' + dnf install -y -q rpm-sign curl >/dev/null 2>&1 + curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key + rpm --import /tmp/rpm-pub.key + echo "=== Verifying RPM signatures ===" + for rpm_file in /dist/*.rpm; do + [ -f "$rpm_file" ] || continue + echo "--- $(basename $rpm_file) ---" + rpm -K "$rpm_file" + done + ' + - name: Clean up GPG key + if: always() + run: rm -f /tmp/gpg-rpm-signing-key.asc - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: diff --git a/.goreleaser.yaml b/.goreleaser.yaml index c0a5efbbe..0f81229cd 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -171,13 +171,12 @@ nfpms: - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ - id: netbird-deb + id: netbird_deb bindir: /usr/bin builds: - netbird formats: - deb - scripts: postinstall: "release_files/post_install.sh" preremove: "release_files/pre_remove.sh" @@ -185,16 +184,18 @@ nfpms: - maintainer: Netbird description: Netbird client. 
homepage: https://netbird.io/ - id: netbird-rpm + id: netbird_rpm bindir: /usr/bin builds: - netbird formats: - rpm - scripts: postinstall: "release_files/post_install.sh" preremove: "release_files/pre_remove.sh" + rpm: + signature: + key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' dockers: - image_templates: - netbirdio/netbird:{{ .Version }}-amd64 @@ -876,7 +877,7 @@ brews: uploads: - name: debian ids: - - netbird-deb + - netbird_deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com @@ -884,7 +885,7 @@ uploads: - name: yum ids: - - netbird-rpm + - netbird_rpm mode: archive target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} username: dev@wiretrustee.com diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index a243702ea..470f1deaa 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -61,7 +61,7 @@ nfpms: - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ - id: netbird-ui-deb + id: netbird_ui_deb package_name: netbird-ui builds: - netbird-ui @@ -80,7 +80,7 @@ nfpms: - maintainer: Netbird description: Netbird client UI. 
homepage: https://netbird.io/ - id: netbird-ui-rpm + id: netbird_ui_rpm package_name: netbird-ui builds: - netbird-ui @@ -95,11 +95,14 @@ nfpms: dst: /usr/share/pixmaps/netbird.png dependencies: - netbird + rpm: + signature: + key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' uploads: - name: debian ids: - - netbird-ui-deb + - netbird_ui_deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com @@ -107,7 +110,7 @@ uploads: - name: yum ids: - - netbird-ui-rpm + - netbird_ui_rpm mode: archive target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} username: dev@wiretrustee.com From f80fe506d5c1f3847015db6223efdf262bb85d77 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 13 Mar 2026 13:22:43 +0100 Subject: [PATCH 16/27] [client] Fix DNS probe thread safety and avoid blocking engine sync (#5576) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix DNS probe thread safety and avoid blocking engine sync Refactor ProbeAvailability to prevent blocking the engine's sync mutex during slow DNS probes. The probe now derives its context from the server's own context (s.ctx) instead of accepting one from the caller, and uses a mutex to ensure only one probe runs at a time — new calls cancel the previous probe before starting. Also fixes a data race in Stop() when accessing probeCancel without the probe mutex. * Ensure DNS probe thread safety by locking critical sections Add proper locking to prevent data races when accessing shared resources during DNS probe execution and Stop(). Update handlers snapshot logic to avoid conflicts with concurrent writers. 
* Rename context and remove redundant cancellation * Cancel first and lock * Add locking to ensure thread safety when reactivating upstream servers --- client/internal/dns/local/local.go | 2 +- client/internal/dns/server.go | 66 ++++++++++++++++++++++++---- client/internal/dns/server_test.go | 2 +- client/internal/dns/upstream.go | 61 ++++++++++++++++--------- client/internal/dns/upstream_test.go | 2 +- client/internal/engine.go | 3 +- 6 files changed, 101 insertions(+), 35 deletions(-) diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index b374bcc6a..a67a23945 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID { return "local-resolver" } -func (d *Resolver) ProbeAvailability() {} +func (d *Resolver) ProbeAvailability(context.Context) {} // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 179517bbd..6ca4f7957 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -104,12 +104,16 @@ type DefaultServer struct { statusRecorder *peer.Status stateManager *statemanager.Manager + + probeMu sync.Mutex + probeCancel context.CancelFunc + probeWg sync.WaitGroup } type handlerWithStop interface { dns.Handler Stop() - ProbeAvailability() + ProbeAvailability(context.Context) ID() types.HandlerID } @@ -362,7 +366,13 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { + s.probeMu.Lock() + if s.probeCancel != nil { + s.probeCancel() + } s.ctxCancel() + s.probeMu.Unlock() + s.probeWg.Wait() s.shutdownWg.Wait() s.mux.Lock() @@ -479,7 +489,8 @@ func (s *DefaultServer) SearchDomains() []string { } // ProbeAvailability tests each upstream group's servers for availability -// and deactivates the group if no server responds +// and deactivates the group 
if no server responds. +// If a previous probe is still running, it will be cancelled before starting a new one. func (s *DefaultServer) ProbeAvailability() { if val := os.Getenv(envSkipDNSProbe); val != "" { skipProbe, err := strconv.ParseBool(val) @@ -492,15 +503,52 @@ func (s *DefaultServer) ProbeAvailability() { } } - var wg sync.WaitGroup - for _, mux := range s.dnsMuxMap { - wg.Add(1) - go func(mux handlerWithStop) { - defer wg.Done() - mux.ProbeAvailability() - }(mux.handler) + s.probeMu.Lock() + + // don't start probes on a stopped server + if s.ctx.Err() != nil { + s.probeMu.Unlock() + return } + + // cancel any running probe + if s.probeCancel != nil { + s.probeCancel() + s.probeCancel = nil + } + + // wait for the previous probe goroutines to finish while holding + // the mutex so no other caller can start a new probe concurrently + s.probeWg.Wait() + + // start a new probe + probeCtx, probeCancel := context.WithCancel(s.ctx) + s.probeCancel = probeCancel + + s.probeWg.Add(1) + defer s.probeWg.Done() + + // Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers. 
+ s.mux.Lock() + handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap)) + for _, mux := range s.dnsMuxMap { + handlers = append(handlers, mux.handler) + } + s.mux.Unlock() + + var wg sync.WaitGroup + for _, handler := range handlers { + wg.Add(1) + go func(h handlerWithStop) { + defer wg.Done() + h.ProbeAvailability(probeCtx) + }(handler) + } + + s.probeMu.Unlock() + wg.Wait() + probeCancel() } func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 3606d48b9..d3b0c250d 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -1065,7 +1065,7 @@ type mockHandler struct { func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} func (m *mockHandler) Stop() {} -func (m *mockHandler) ProbeAvailability() {} +func (m *mockHandler) ProbeAvailability(context.Context) {} func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 375f6df1c..18128a942 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -65,6 +65,7 @@ type upstreamResolverBase struct { mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration + wg sync.WaitGroup deactivate func(error) reactivate func() @@ -115,6 +116,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { func (u *upstreamResolverBase) Stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() + + u.mutex.Lock() + u.wg.Wait() + u.mutex.Unlock() + } // ServeDNS handles a DNS request @@ -260,16 +266,10 @@ func formatFailures(failures []upstreamFailure) string { // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work -func (u *upstreamResolverBase) ProbeAvailability() { +func (u *upstreamResolverBase) ProbeAvailability(ctx 
context.Context) { u.mutex.Lock() defer u.mutex.Unlock() - select { - case <-u.ctx.Done(): - return - default: - } - // avoid probe if upstreams could resolve at least one query if u.successCount.Load() > 0 { return @@ -279,31 +279,39 @@ func (u *upstreamResolverBase) ProbeAvailability() { var mu sync.Mutex var wg sync.WaitGroup - var errors *multierror.Error + var errs *multierror.Error for _, upstream := range u.upstreamServers { - upstream := upstream - wg.Add(1) - go func() { + go func(upstream netip.AddrPort) { defer wg.Done() - err := u.testNameserver(upstream, 500*time.Millisecond) + err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond) if err != nil { - errors = multierror.Append(errors, err) + mu.Lock() + errs = multierror.Append(errs, err) + mu.Unlock() log.Warnf("probing upstream nameserver %s: %s", upstream, err) return } mu.Lock() - defer mu.Unlock() success = true - }() + mu.Unlock() + }(upstream) } wg.Wait() + select { + case <-ctx.Done(): + return + case <-u.ctx.Done(): + return + default: + } + // didn't find a working upstream server, let's disable and try later if !success { - u.disable(errors.ErrorOrNil()) + u.disable(errs.ErrorOrNil()) if u.statusRecorder == nil { return @@ -339,7 +347,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } for _, upstream := range u.upstreamServers { - if err := u.testNameserver(upstream, probeTimeout); err != nil { + if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { log.Tracef("upstream check for %s: %s", upstream, err) } else { // at least one upstream server is available, stop probing @@ -364,7 +372,9 @@ func (u *upstreamResolverBase) waitUntilResponse() { log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) u.successCount.Add(1) u.reactivate() + u.mutex.Lock() u.disabled = false + u.mutex.Unlock() } // isTimeout returns true if the given error is a network timeout error. 
@@ -387,7 +397,11 @@ func (u *upstreamResolverBase) disable(err error) { u.successCount.Store(0) u.deactivate(err) u.disabled = true - go u.waitUntilResponse() + u.wg.Add(1) + go func() { + defer u.wg.Done() + u.waitUntilResponse() + }() } func (u *upstreamResolverBase) upstreamServersString() string { @@ -398,13 +412,18 @@ func (u *upstreamResolverBase) upstreamServersString() string { return strings.Join(servers, ", ") } -func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(u.ctx, timeout) +func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { + mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) defer cancel() + if externalCtx != nil { + stop2 := context.AfterFunc(externalCtx, cancel) + defer stop2() + } + r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - _, _, err := u.upstreamClient.exchange(ctx, server.String(), r) + _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) return err } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 8b06e4475..ab164c30b 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ProbeAvailability() + resolver.ProbeAvailability(context.TODO()) if !failed { t.Errorf("expected that resolving was deactivated") diff --git a/client/internal/engine.go b/client/internal/engine.go index b0ae841f8..858202155 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1315,8 +1315,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { // Test received (upstream) servers for availability right away instead of upon usage. 
// If no server of a server group responds this will disable the respective handler and retry later. - e.dnsServer.ProbeAvailability() - + go e.dnsServer.ProbeAvailability() return nil } From d86875aeac88d9a57021df5105cf9ccc2242b1e1 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:01:59 +0100 Subject: [PATCH 17/27] [management] Exclude proxy from peer approval (#5588) --- management/internals/modules/peers/manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index 2f796a5d1..7cb0f3908 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -210,7 +210,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee }, } - _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false) + _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true) if err != nil { return fmt.Errorf("failed to create proxy peer: %w", err) } From 529c0314f84b6344d086b4771f8a1086e61f5e14 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 13 Mar 2026 22:22:02 +0800 Subject: [PATCH 18/27] [client] Fall back to getent/id for SSH user lookup in static builds (#5510) --- client/ssh/server/getent_cgo_unix.go | 24 ++ client/ssh/server/getent_nocgo_unix.go | 74 +++++ client/ssh/server/getent_test.go | 172 ++++++++++ client/ssh/server/getent_unix.go | 122 +++++++ client/ssh/server/getent_unix_test.go | 410 ++++++++++++++++++++++++ client/ssh/server/getent_windows.go | 26 ++ client/ssh/server/shell.go | 10 +- client/ssh/server/user_utils.go | 4 +- client/ssh/server/userswitching_unix.go | 24 +- 9 files changed, 848 insertions(+), 18 deletions(-) create mode 100644 client/ssh/server/getent_cgo_unix.go create mode 100644 client/ssh/server/getent_nocgo_unix.go 
create mode 100644 client/ssh/server/getent_test.go create mode 100644 client/ssh/server/getent_unix.go create mode 100644 client/ssh/server/getent_unix_test.go create mode 100644 client/ssh/server/getent_windows.go diff --git a/client/ssh/server/getent_cgo_unix.go b/client/ssh/server/getent_cgo_unix.go new file mode 100644 index 000000000..4afbfc627 --- /dev/null +++ b/client/ssh/server/getent_cgo_unix.go @@ -0,0 +1,24 @@ +//go:build cgo && !osusergo && !windows + +package server + +import "os/user" + +// lookupWithGetent with CGO delegates directly to os/user.Lookup. +// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through +// the NSS stack natively. If it fails, the user truly doesn't exist and +// getent would also fail. +func lookupWithGetent(username string) (*user.User, error) { + return user.Lookup(username) +} + +// currentUserWithGetent with CGO delegates directly to os/user.Current. +func currentUserWithGetent() (*user.User, error) { + return user.Current() +} + +// groupIdsWithFallback with CGO delegates directly to user.GroupIds. +// libc's getgrouplist handles NSS groups natively. +func groupIdsWithFallback(u *user.User) ([]string, error) { + return u.GroupIds() +} diff --git a/client/ssh/server/getent_nocgo_unix.go b/client/ssh/server/getent_nocgo_unix.go new file mode 100644 index 000000000..314daae4c --- /dev/null +++ b/client/ssh/server/getent_nocgo_unix.go @@ -0,0 +1,74 @@ +//go:build (!cgo || osusergo) && !windows + +package server + +import ( + "os" + "os/user" + "strconv" + + log "github.com/sirupsen/logrus" +) + +// lookupWithGetent looks up a user by name, falling back to getent if os/user fails. +// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users. +// getent goes through the host's NSS stack. 
+func lookupWithGetent(username string) (*user.User, error) { + u, err := user.Lookup(username) + if err == nil { + return u, nil + } + + stdErr := err + log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err) + + u, _, getentErr := runGetent(username) + if getentErr != nil { + log.Debugf("getent fallback for %q also failed: %v", username, getentErr) + return nil, stdErr + } + + return u, nil +} + +// currentUserWithGetent gets the current user, falling back to getent if os/user fails. +func currentUserWithGetent() (*user.User, error) { + u, err := user.Current() + if err == nil { + return u, nil + } + + stdErr := err + uid := strconv.Itoa(os.Getuid()) + log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err) + + u, _, getentErr := runGetent(uid) + if getentErr != nil { + return nil, stdErr + } + + return u, nil +} + +// groupIdsWithFallback gets group IDs for a user via the id command first, +// falling back to user.GroupIds(). +// NOTE: unlike lookupWithGetent/currentUserWithGetent which try stdlib first, +// this intentionally tries `id -G` first because without CGO, user.GroupIds() +// only reads /etc/group and silently returns incomplete results for NSS users +// (no error, just missing groups). The id command goes through NSS and returns +// the full set. 
+func groupIdsWithFallback(u *user.User) ([]string, error) { + ids, err := runIdGroups(u.Username) + if err == nil { + return ids, nil + } + + log.Debugf("id -G %q failed, falling back to user.GroupIds(): %v", u.Username, err) + + ids, stdErr := u.GroupIds() + if stdErr != nil { + return nil, stdErr + } + + return ids, nil +} diff --git a/client/ssh/server/getent_test.go b/client/ssh/server/getent_test.go new file mode 100644 index 000000000..5eac2fdbe --- /dev/null +++ b/client/ssh/server/getent_test.go @@ -0,0 +1,172 @@ +package server + +import ( + "os/user" + "runtime" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLookupWithGetent_CurrentUser(t *testing.T) { + // The current user should always be resolvable on any platform + current, err := user.Current() + require.NoError(t, err) + + u, err := lookupWithGetent(current.Username) + require.NoError(t, err) + assert.Equal(t, current.Username, u.Username) + assert.Equal(t, current.Uid, u.Uid) + assert.Equal(t, current.Gid, u.Gid) +} + +func TestLookupWithGetent_NonexistentUser(t *testing.T) { + _, err := lookupWithGetent("nonexistent_user_xyzzy_12345") + require.Error(t, err, "should fail for nonexistent user") +} + +func TestCurrentUserWithGetent(t *testing.T) { + stdUser, err := user.Current() + require.NoError(t, err) + + u, err := currentUserWithGetent() + require.NoError(t, err) + assert.Equal(t, stdUser.Uid, u.Uid) + assert.Equal(t, stdUser.Username, u.Username) +} + +func TestGroupIdsWithFallback_CurrentUser(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + groups, err := groupIdsWithFallback(current) + require.NoError(t, err) + require.NotEmpty(t, groups, "current user should have at least one group") + + if runtime.GOOS != "windows" { + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be a valid uint32", gid) + } + } +} + +func 
TestGetShellFromGetent_CurrentUser(t *testing.T) { + if runtime.GOOS == "windows" { + // Windows stub always returns empty, which is correct + shell := getShellFromGetent("1000") + assert.Empty(t, shell, "Windows stub should return empty") + return + } + + current, err := user.Current() + require.NoError(t, err) + + // getent may not be available on all systems (e.g., macOS without Homebrew getent) + shell := getShellFromGetent(current.Uid) + if shell == "" { + t.Log("getShellFromGetent returned empty, getent may not be available") + return + } + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) +} + +func TestLookupWithGetent_RootUser(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("no root user on Windows") + } + + u, err := lookupWithGetent("root") + if err != nil { + t.Skip("root user not available on this system") + } + assert.Equal(t, "0", u.Uid, "root should have UID 0") +} + +// TestIntegration_FullLookupChain exercises the complete user lookup chain +// against the real system, testing that all wrappers (lookupWithGetent, +// currentUserWithGetent, groupIdsWithFallback, getShellFromGetent) produce +// consistent and correct results when composed together. +func TestIntegration_FullLookupChain(t *testing.T) { + // Step 1: currentUserWithGetent must resolve the running user. + current, err := currentUserWithGetent() + require.NoError(t, err, "currentUserWithGetent must resolve the running user") + require.NotEmpty(t, current.Uid) + require.NotEmpty(t, current.Username) + + // Step 2: lookupWithGetent by the same username must return matching identity. 
+ byName, err := lookupWithGetent(current.Username) + require.NoError(t, err) + assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID") + assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID") + assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home") + + // Step 3: groupIdsWithFallback must return at least the primary GID. + groups, err := groupIdsWithFallback(current) + require.NoError(t, err) + require.NotEmpty(t, groups, "user must have at least one group") + + foundPrimary := false + for _, gid := range groups { + if runtime.GOOS != "windows" { + _, err := strconv.ParseUint(gid, 10, 32) + require.NoError(t, err, "group ID %q must be a valid uint32", gid) + } + if gid == current.Gid { + foundPrimary = true + } + } + assert.True(t, foundPrimary, "primary GID %s should appear in supplementary groups", current.Gid) + + // Step 4: getShellFromGetent should either return a valid shell path or empty + // (empty is OK when getent is not available, e.g. macOS without Homebrew getent). + if runtime.GOOS != "windows" { + shell := getShellFromGetent(current.Uid) + if shell != "" { + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) + } + } +} + +// TestIntegration_LookupAndGroupsConsistency verifies that a user resolved via +// lookupWithGetent can have their groups resolved via groupIdsWithFallback, +// testing the handoff between the two functions as used by the SSH server. +func TestIntegration_LookupAndGroupsConsistency(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + // Simulate the SSH server flow: lookup user, then get their groups. + resolved, err := lookupWithGetent(current.Username) + require.NoError(t, err) + + groups, err := groupIdsWithFallback(resolved) + require.NoError(t, err) + require.NotEmpty(t, groups, "resolved user must have groups") + + // On Unix, all returned GIDs must be valid numeric values. 
+	// On Windows, group IDs are SIDs (e.g., "S-1-5-32-544").
+	if runtime.GOOS != "windows" {
+		for _, gid := range groups {
+			_, err := strconv.ParseUint(gid, 10, 32)
+			assert.NoError(t, err, "group ID %q should be numeric", gid)
+		}
+	}
+}
+
+// TestIntegration_ShellLookupChain tests the full shell resolution chain
+// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
+func TestIntegration_ShellLookupChain(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("Unix shell lookup not applicable on Windows")
+	}
+
+	current, err := user.Current()
+	require.NoError(t, err)
+
+	// getUserShell is the top-level function used by the SSH server.
+	shell := getUserShell(current.Uid)
+	require.NotEmpty(t, shell, "getUserShell must always return a shell")
+	assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
+}
diff --git a/client/ssh/server/getent_unix.go b/client/ssh/server/getent_unix.go
new file mode 100644
index 000000000..18edb2fdf
--- /dev/null
+++ b/client/ssh/server/getent_unix.go
@@ -0,0 +1,122 @@
+//go:build !windows
+
+package server
+
+import (
+	"context"
+	"fmt"
+	"os/exec"
+	"os/user"
+	"runtime"
+	"strings"
+	"time"
+)
+
+const getentTimeout = 5 * time.Second
+
+// getShellFromGetent gets a user's login shell via getent by UID.
+// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
+// directly and won't find NSS-provided users there.
+func getShellFromGetent(userID string) string {
+	_, shell, err := runGetent(userID)
+	if err != nil {
+		return ""
+	}
+	return shell
+}
+
+// runGetent executes `getent passwd <query>` and returns the user and login shell.
+func runGetent(query string) (*user.User, string, error) {
+	if !validateGetentInput(query) {
+		return nil, "", fmt.Errorf("invalid getent input: %q", query)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
+	defer cancel()
+
+	out, err := exec.CommandContext(ctx, "getent", "passwd", query).Output()
+	if err != nil {
+		return nil, "", fmt.Errorf("getent passwd %s: %w", query, err)
+	}
+
+	return parseGetentPasswd(string(out))
+}
+
+// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
+func parseGetentPasswd(output string) (*user.User, string, error) {
+	fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
+	if len(fields) < 6 {
+		return nil, "", fmt.Errorf("unexpected getent output (need 6+ fields): %q", output)
+	}
+
+	if fields[0] == "" || fields[2] == "" || fields[3] == "" {
+		return nil, "", fmt.Errorf("missing required fields in getent output: %q", output)
+	}
+
+	var shell string
+	if len(fields) >= 7 {
+		shell = fields[6]
+	}
+
+	return &user.User{
+		Username: fields[0],
+		Uid:      fields[2],
+		Gid:      fields[3],
+		Name:     fields[4],
+		HomeDir:  fields[5],
+	}, shell, nil
+}
+
+// validateGetentInput checks that the input is safe to pass to getent or id.
+// Allows POSIX usernames, numeric UIDs, and common NSS extensions
+// (@ for Kerberos, $ for Samba, + for NIS compat).
+func validateGetentInput(input string) bool {
+	maxLen := 32
+	if runtime.GOOS == "linux" {
+		maxLen = 256
+	}
+
+	if len(input) == 0 || len(input) > maxLen {
+		return false
+	}
+
+	for _, r := range input {
+		if isAllowedGetentChar(r) {
+			continue
+		}
+		return false
+	}
+	return true
+}
+
+func isAllowedGetentChar(r rune) bool {
+	if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' {
+		return true
+	}
+	switch r {
+	case '.', '_', '-', '@', '+', '$':
+		return true
+	}
+	return false
+}
+
+// runIdGroups runs `id -G <username>` and returns the space-separated group IDs.
+func runIdGroups(username string) ([]string, error) { + if !validateGetentInput(username) { + return nil, fmt.Errorf("invalid username for id command: %q", username) + } + + ctx, cancel := context.WithTimeout(context.Background(), getentTimeout) + defer cancel() + + out, err := exec.CommandContext(ctx, "id", "-G", username).Output() + if err != nil { + return nil, fmt.Errorf("id -G %s: %w", username, err) + } + + trimmed := strings.TrimSpace(string(out)) + if trimmed == "" { + return nil, fmt.Errorf("id -G %s: empty output", username) + } + return strings.Fields(trimmed), nil +} diff --git a/client/ssh/server/getent_unix_test.go b/client/ssh/server/getent_unix_test.go new file mode 100644 index 000000000..e44563b79 --- /dev/null +++ b/client/ssh/server/getent_unix_test.go @@ -0,0 +1,410 @@ +//go:build !windows + +package server + +import ( + "os/exec" + "os/user" + "runtime" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseGetentPasswd(t *testing.T) { + tests := []struct { + name string + input string + wantUser *user.User + wantShell string + wantErr bool + errContains string + }{ + { + name: "standard entry", + input: "alice:x:1001:1001:Alice Smith:/home/alice:/bin/bash\n", + wantUser: &user.User{ + Username: "alice", + Uid: "1001", + Gid: "1001", + Name: "Alice Smith", + HomeDir: "/home/alice", + }, + wantShell: "/bin/bash", + }, + { + name: "root entry", + input: "root:x:0:0:root:/root:/bin/bash", + wantUser: &user.User{ + Username: "root", + Uid: "0", + Gid: "0", + Name: "root", + HomeDir: "/root", + }, + wantShell: "/bin/bash", + }, + { + name: "empty gecos field", + input: "svc:x:999:999::/var/lib/svc:/usr/sbin/nologin", + wantUser: &user.User{ + Username: "svc", + Uid: "999", + Gid: "999", + Name: "", + HomeDir: "/var/lib/svc", + }, + wantShell: "/usr/sbin/nologin", + }, + { + name: "gecos with commas", + input: "john:x:1002:1002:John Doe,Room 
101,555-1234,555-4321:/home/john:/bin/zsh", + wantUser: &user.User{ + Username: "john", + Uid: "1002", + Gid: "1002", + Name: "John Doe,Room 101,555-1234,555-4321", + HomeDir: "/home/john", + }, + wantShell: "/bin/zsh", + }, + { + name: "remote user with large UID", + input: "remoteuser:*:50001:50001:Remote User:/home/remoteuser:/bin/bash\n", + wantUser: &user.User{ + Username: "remoteuser", + Uid: "50001", + Gid: "50001", + Name: "Remote User", + HomeDir: "/home/remoteuser", + }, + wantShell: "/bin/bash", + }, + { + name: "no shell field (only 6 fields)", + input: "minimal:x:1000:1000::/home/minimal", + wantUser: &user.User{ + Username: "minimal", + Uid: "1000", + Gid: "1000", + Name: "", + HomeDir: "/home/minimal", + }, + wantShell: "", + }, + { + name: "too few fields", + input: "bad:x:1000", + wantErr: true, + errContains: "need 6+ fields", + }, + { + name: "empty username", + input: ":x:1000:1000::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty UID", + input: "test:x::1000::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty GID", + input: "test:x:1000:::/home/test:/bin/bash", + wantErr: true, + errContains: "missing required fields", + }, + { + name: "empty input", + input: "", + wantErr: true, + errContains: "need 6+ fields", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, shell, err := parseGetentPasswd(tt.input) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantUser.Username, u.Username, "username") + assert.Equal(t, tt.wantUser.Uid, u.Uid, "UID") + assert.Equal(t, tt.wantUser.Gid, u.Gid, "GID") + assert.Equal(t, tt.wantUser.Name, u.Name, "name/gecos") + assert.Equal(t, tt.wantUser.HomeDir, u.HomeDir, "home directory") + assert.Equal(t, tt.wantShell, shell, "shell") + }) + } 
+} + +func TestValidateGetentInput(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"normal username", "alice", true}, + {"numeric UID", "1001", true}, + {"dots and underscores", "alice.bob_test", true}, + {"hyphen", "alice-bob", true}, + {"kerberos principal", "user@REALM", true}, + {"samba machine account", "MACHINE$", true}, + {"NIS compat", "+user", true}, + {"empty", "", false}, + {"null byte", "alice\x00bob", false}, + {"newline", "alice\nbob", false}, + {"tab", "alice\tbob", false}, + {"control char", "alice\x01bob", false}, + {"DEL char", "alice\x7fbob", false}, + {"space rejected", "alice bob", false}, + {"semicolon rejected", "alice;bob", false}, + {"backtick rejected", "alice`bob", false}, + {"pipe rejected", "alice|bob", false}, + {"33 chars exceeds non-linux max", makeLongString(33), runtime.GOOS == "linux"}, + {"256 chars at linux max", makeLongString(256), runtime.GOOS == "linux"}, + {"257 chars exceeds all limits", makeLongString(257), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, validateGetentInput(tt.input)) + }) + } +} + +func makeLongString(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = 'a' + } + return string(b) +} + +func TestRunGetent_RootUser(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + u, shell, err := runGetent("root") + require.NoError(t, err) + assert.Equal(t, "root", u.Username) + assert.Equal(t, "0", u.Uid) + assert.Equal(t, "0", u.Gid) + assert.NotEmpty(t, shell, "root should have a shell") +} + +func TestRunGetent_ByUID(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + u, _, err := runGetent("0") + require.NoError(t, err) + assert.Equal(t, "root", u.Username) + assert.Equal(t, "0", u.Uid) +} + +func TestRunGetent_NonexistentUser(t *testing.T) { + if _, err := 
exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + _, _, err := runGetent("nonexistent_user_xyzzy_12345") + assert.Error(t, err) +} + +func TestRunGetent_InvalidInput(t *testing.T) { + _, _, err := runGetent("") + assert.Error(t, err) + + _, _, err = runGetent("user\x00name") + assert.Error(t, err) +} + +func TestRunGetent_NotAvailable(t *testing.T) { + if _, err := exec.LookPath("getent"); err == nil { + t.Skip("getent is available, can't test missing case") + } + + _, _, err := runGetent("root") + assert.Error(t, err, "should fail when getent is not installed") +} + +func TestRunIdGroups_CurrentUser(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + groups, err := runIdGroups(current.Username) + require.NoError(t, err) + require.NotEmpty(t, groups, "current user should have at least one group") + + for _, gid := range groups { + _, err := strconv.ParseUint(gid, 10, 32) + assert.NoError(t, err, "group ID %q should be a valid uint32", gid) + } +} + +func TestRunIdGroups_NonexistentUser(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + _, err := runIdGroups("nonexistent_user_xyzzy_12345") + assert.Error(t, err) +} + +func TestRunIdGroups_InvalidInput(t *testing.T) { + _, err := runIdGroups("") + assert.Error(t, err) + + _, err = runIdGroups("user\x00name") + assert.Error(t, err) +} + +func TestGetentResultsMatchStdlib(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + getentUser, _, err := runGetent(current.Username) + require.NoError(t, err) + + assert.Equal(t, current.Username, getentUser.Username, "username should match") + assert.Equal(t, current.Uid, getentUser.Uid, "UID should match") + 
assert.Equal(t, current.Gid, getentUser.Gid, "GID should match") + assert.Equal(t, current.HomeDir, getentUser.HomeDir, "home directory should match") +} + +func TestGetentResultsMatchStdlib_ByUID(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + getentUser, _, err := runGetent(current.Uid) + require.NoError(t, err) + + assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID") + assert.Equal(t, current.Uid, getentUser.Uid, "UID should match") +} + +func TestIdGroupsMatchStdlib(t *testing.T) { + if _, err := exec.LookPath("id"); err != nil { + t.Skip("id not available on this system") + } + + current, err := user.Current() + require.NoError(t, err) + + stdGroups, err := current.GroupIds() + if err != nil { + t.Skip("os/user.GroupIds() not working, likely CGO_ENABLED=0") + } + + idGroups, err := runIdGroups(current.Username) + require.NoError(t, err) + + // Deduplicate both lists: id -G can return duplicates (e.g., root in Docker) + // and ElementsMatch treats duplicates as distinct. + assert.ElementsMatch(t, uniqueStrings(stdGroups), uniqueStrings(idGroups), "id -G should return same groups as os/user") +} + +func uniqueStrings(ss []string) []string { + seen := make(map[string]struct{}, len(ss)) + out := make([]string, 0, len(ss)) + for _, s := range ss { + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + return out +} + +// TestGetShellFromPasswd_CurrentUser verifies that getShellFromPasswd correctly +// reads the current user's shell from /etc/passwd by comparing it against what +// getent reports (which goes through NSS). 
+func TestGetShellFromPasswd_CurrentUser(t *testing.T) { + current, err := user.Current() + require.NoError(t, err) + + shell := getShellFromPasswd(current.Uid) + if shell == "" { + t.Skip("current user not found in /etc/passwd (may be an NSS-only user)") + } + + assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell) + + if _, err := exec.LookPath("getent"); err == nil { + _, getentShell, getentErr := runGetent(current.Uid) + if getentErr == nil && getentShell != "" { + assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent") + } + } +} + +// TestGetShellFromPasswd_RootUser verifies that getShellFromPasswd can read +// root's shell from /etc/passwd. Root is guaranteed to be in /etc/passwd on +// any standard Unix system. +func TestGetShellFromPasswd_RootUser(t *testing.T) { + shell := getShellFromPasswd("0") + require.NotEmpty(t, shell, "root (UID 0) must be in /etc/passwd") + assert.True(t, shell[0] == '/', "root shell should be an absolute path, got %q", shell) +} + +// TestGetShellFromPasswd_NonexistentUID verifies that getShellFromPasswd +// returns empty for a UID that doesn't exist in /etc/passwd. +func TestGetShellFromPasswd_NonexistentUID(t *testing.T) { + shell := getShellFromPasswd("4294967294") + assert.Empty(t, shell, "nonexistent UID should return empty shell") +} + +// TestGetShellFromPasswd_MatchesGetentForKnownUsers reads /etc/passwd directly +// and cross-validates every entry against getent to ensure parseGetentPasswd +// and getShellFromPasswd agree on shell values. +func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) { + if _, err := exec.LookPath("getent"); err != nil { + t.Skip("getent not available") + } + + // Pick a few well-known system UIDs that are virtually always in /etc/passwd. 
+ uids := []string{"0"} // root + + current, err := user.Current() + require.NoError(t, err) + uids = append(uids, current.Uid) + + for _, uid := range uids { + passwdShell := getShellFromPasswd(uid) + if passwdShell == "" { + continue + } + + _, getentShell, err := runGetent(uid) + if err != nil { + continue + } + + assert.Equal(t, getentShell, passwdShell, "shell mismatch for UID %s", uid) + } +} diff --git a/client/ssh/server/getent_windows.go b/client/ssh/server/getent_windows.go new file mode 100644 index 000000000..3e76b3e8e --- /dev/null +++ b/client/ssh/server/getent_windows.go @@ -0,0 +1,26 @@ +//go:build windows + +package server + +import "os/user" + +// lookupWithGetent on Windows just delegates to os/user.Lookup. +// Windows does not use NSS/getent; its user lookup works without CGO. +func lookupWithGetent(username string) (*user.User, error) { + return user.Lookup(username) +} + +// currentUserWithGetent on Windows just delegates to os/user.Current. +func currentUserWithGetent() (*user.User, error) { + return user.Current() +} + +// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection. +func getShellFromGetent(_ string) string { + return "" +} + +// groupIdsWithFallback on Windows just delegates to u.GroupIds(). +func groupIdsWithFallback(u *user.User) ([]string, error) { + return u.GroupIds() +} diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go index fea9d2910..1e8ff5e31 100644 --- a/client/ssh/server/shell.go +++ b/client/ssh/server/shell.go @@ -49,10 +49,14 @@ func getWindowsUserShell() string { return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` } -// getUnixUserShell returns the shell for Unix-like systems +// getUnixUserShell returns the shell for Unix-like systems. +// Tries /etc/passwd first (fast, no subprocess), falls back to getent for NSS users. 
func getUnixUserShell(userID string) string { - shell := getShellFromPasswd(userID) - if shell != "" { + if shell := getShellFromPasswd(userID); shell != "" { + return shell + } + + if shell := getShellFromGetent(userID); shell != "" { return shell } diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go index 799882cbb..bc2aa2d7d 100644 --- a/client/ssh/server/user_utils.go +++ b/client/ssh/server/user_utils.go @@ -23,8 +23,8 @@ func isPlatformUnix() bool { // Dependency injection variables for testing - allows mocking dynamic runtime checks var ( - getCurrentUser = user.Current - lookupUser = user.Lookup + getCurrentUser = currentUserWithGetent + lookupUser = lookupWithGetent getCurrentOS = func() string { return runtime.GOOS } getIsProcessPrivileged = isCurrentProcessPrivileged diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index d80b77042..220e2240f 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -146,32 +146,30 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u } gid := uint32(gid64) - groups, err := s.getSupplementaryGroups(localUser.Username) - if err != nil { - log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + groups, err := s.getSupplementaryGroups(localUser) + if err != nil || len(groups) == 0 { + if err != nil { + log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + } groups = []uint32{gid} } return uid, gid, groups, nil } -// getSupplementaryGroups retrieves supplementary group IDs for a user -func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { - u, err := user.Lookup(username) +// getSupplementaryGroups retrieves supplementary group IDs for a user. +// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds. 
+func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) { + groupIDStrings, err := groupIdsWithFallback(u) if err != nil { - return nil, fmt.Errorf("lookup user %s: %w", username, err) - } - - groupIDStrings, err := u.GroupIds() - if err != nil { - return nil, fmt.Errorf("get group IDs for user %s: %w", username, err) + return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err) } groups := make([]uint32, len(groupIDStrings)) for i, gidStr := range groupIDStrings { gid64, err := strconv.ParseUint(gidStr, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err) + return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, u.Username, err) } groups[i] = uint32(gid64) } From 2e1aa497d25c907e11d2f1a525f968daac0fb253 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:28:25 +0100 Subject: [PATCH 19/27] [proxy] add log-level flag (#5594) --- proxy/cmd/proxy/cmd/root.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 61ed5871e..60e81feb5 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -34,6 +34,7 @@ var ( ) var ( + logLevel string debugLogs bool mgmtAddr string addr string @@ -69,7 +70,9 @@ var rootCmd = &cobra.Command{ } func init() { + rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", envStringOrDefault("NB_PROXY_LOG_LEVEL", "info"), "Log level: panic, fatal, error, warn, info, debug, trace") rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs") + _ = rootCmd.PersistentFlags().MarkDeprecated("debug", "use --log-level instead") rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to") rootCmd.Flags().StringVar(&addr, 
"addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on") rootCmd.Flags().StringVar(&proxyDomain, "domain", envStringOrDefault("NB_PROXY_DOMAIN", ""), "The Domain at which this proxy will be reached. e.g., netbird.example.com") @@ -117,7 +120,7 @@ func runServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("proxy token is required: set %s environment variable", envProxyToken) } - level := "error" + level := logLevel if debugLogs { level = "debug" } From fe9b844511187b3d5433e84456bcf3719f83175c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 13 Mar 2026 17:01:28 +0100 Subject: [PATCH 20/27] [client] refactor auto update workflow (#5448) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Auto-update logic moved out of the UI into a dedicated updatemanager.Manager service that runs in the connection layer. The UI no longer polls or checks for updates independently. The update manager supports three modes driven by the management server's auto-update policy: No policy set by mgm: checks GitHub for the latest version and notifies the user (previous behavior, now centralized) mgm enforces update: the "About" menu triggers installation directly instead of just downloading the file — user still initiates the action mgm forces update: installation proceeds automatically without user interaction updateManager lifecycle is now owned by daemon, giving the daemon server direct control via a new TriggerUpdate RPC Introduces EngineServices struct to group external service dependencies passed to NewEngine, reducing its argument count from 11 to 4 --- client/android/client.go | 4 +- client/cmd/signer/artifactkey.go | 2 +- client/cmd/signer/artifactsign.go | 2 +- client/cmd/signer/revocation.go | 2 +- client/cmd/signer/rootkey.go | 2 +- client/cmd/up.go | 2 +- client/cmd/update_supported.go | 2 +- client/embed/embed.go | 2 +- client/internal/connect.go | 62 +-- 
client/internal/debug/debug.go | 2 +- client/internal/engine.go | 84 ++-- client/internal/engine_test.go | 55 ++- client/internal/updatemanager/manager_test.go | 214 ---------- .../updatemanager/manager_unsupported.go | 39 -- .../{updatemanager => updater}/doc.go | 4 +- .../downloader/downloader.go | 0 .../downloader/downloader_test.go | 0 .../installer/binary_nowindows.go | 0 .../installer/binary_windows.go | 0 .../installer/doc.go | 0 .../installer/installer.go | 0 .../installer/installer_common.go | 4 +- .../installer/installer_log_darwin.go | 0 .../installer/installer_log_windows.go | 0 .../installer/installer_run_darwin.go | 0 .../installer/installer_run_windows.go | 0 .../installer/log.go | 0 .../installer/procattr_darwin.go | 0 .../installer/procattr_windows.go | 0 .../installer/repourl_dev.go | 0 .../installer/repourl_prod.go | 0 .../installer/result.go | 5 +- .../installer/types.go | 0 .../installer/types_darwin.go | 0 .../installer/types_windows.go | 0 .../{updatemanager => updater}/manager.go | 268 ++++++++++--- client/internal/updater/manager_linux_test.go | 111 ++++++ client/internal/updater/manager_test.go | 227 +++++++++++ .../updater/manager_test_helpers_test.go | 56 +++ .../reposign/artifact.go | 0 .../reposign/artifact_test.go | 0 .../reposign/certs/root-pub.pem | 0 .../reposign/certsdev/root-pub.pem | 0 .../reposign/doc.go | 0 .../reposign/embed_dev.go | 0 .../reposign/embed_prod.go | 0 .../reposign/key.go | 0 .../reposign/key_test.go | 0 .../reposign/revocation.go | 0 .../reposign/revocation_test.go | 0 .../reposign/root.go | 0 .../reposign/root_test.go | 0 .../reposign/signature.go | 0 .../reposign/signature_test.go | 0 .../reposign/verify.go | 2 +- .../reposign/verify_test.go | 0 client/internal/updater/supported_darwin.go | 22 + client/internal/updater/supported_other.go | 7 + client/internal/updater/supported_windows.go | 5 + .../{updatemanager => updater}/update.go | 2 +- client/ios/NetBirdSDK/client.go | 2 +- client/proto/daemon.pb.go | 
377 +++++++++++------- client/proto/daemon.proto | 13 +- client/proto/daemon_grpc.pb.go | 40 ++ client/server/event.go | 1 + client/server/server.go | 38 +- client/server/server_connect_test.go | 2 +- client/server/server_test.go | 2 +- client/server/triggerupdate.go | 24 ++ client/server/updateresult.go | 2 +- client/ui/client_ui.go | 44 +- client/ui/event/event.go | 7 +- client/ui/event_handler.go | 39 +- client/ui/profile.go | 6 +- client/ui/quickactions.go | 2 +- .../internals/shared/grpc/conversion.go | 3 +- management/server/account.go | 14 +- management/server/activity/codes.go | 7 + .../handlers/accounts/accounts_handler.go | 4 + .../accounts/accounts_handler_test.go | 6 + management/server/types/settings.go | 5 + shared/management/http/api/openapi.yml | 4 + shared/management/http/api/types.gen.go | 3 + shared/management/proto/management.proto | 4 +- 84 files changed, 1210 insertions(+), 626 deletions(-) delete mode 100644 client/internal/updatemanager/manager_test.go delete mode 100644 client/internal/updatemanager/manager_unsupported.go rename client/internal/{updatemanager => updater}/doc.go (93%) rename client/internal/{updatemanager => updater}/downloader/downloader.go (100%) rename client/internal/{updatemanager => updater}/downloader/downloader_test.go (100%) rename client/internal/{updatemanager => updater}/installer/binary_nowindows.go (100%) rename client/internal/{updatemanager => updater}/installer/binary_windows.go (100%) rename client/internal/{updatemanager => updater}/installer/doc.go (100%) rename client/internal/{updatemanager => updater}/installer/installer.go (100%) rename client/internal/{updatemanager => updater}/installer/installer_common.go (97%) rename client/internal/{updatemanager => updater}/installer/installer_log_darwin.go (100%) rename client/internal/{updatemanager => updater}/installer/installer_log_windows.go (100%) rename client/internal/{updatemanager => updater}/installer/installer_run_darwin.go (100%) rename 
client/internal/{updatemanager => updater}/installer/installer_run_windows.go (100%) rename client/internal/{updatemanager => updater}/installer/log.go (100%) rename client/internal/{updatemanager => updater}/installer/procattr_darwin.go (100%) rename client/internal/{updatemanager => updater}/installer/procattr_windows.go (100%) rename client/internal/{updatemanager => updater}/installer/repourl_dev.go (100%) rename client/internal/{updatemanager => updater}/installer/repourl_prod.go (100%) rename client/internal/{updatemanager => updater}/installer/result.go (98%) rename client/internal/{updatemanager => updater}/installer/types.go (100%) rename client/internal/{updatemanager => updater}/installer/types_darwin.go (100%) rename client/internal/{updatemanager => updater}/installer/types_windows.go (100%) rename client/internal/{updatemanager => updater}/manager.go (52%) create mode 100644 client/internal/updater/manager_linux_test.go create mode 100644 client/internal/updater/manager_test.go create mode 100644 client/internal/updater/manager_test_helpers_test.go rename client/internal/{updatemanager => updater}/reposign/artifact.go (100%) rename client/internal/{updatemanager => updater}/reposign/artifact_test.go (100%) rename client/internal/{updatemanager => updater}/reposign/certs/root-pub.pem (100%) rename client/internal/{updatemanager => updater}/reposign/certsdev/root-pub.pem (100%) rename client/internal/{updatemanager => updater}/reposign/doc.go (100%) rename client/internal/{updatemanager => updater}/reposign/embed_dev.go (100%) rename client/internal/{updatemanager => updater}/reposign/embed_prod.go (100%) rename client/internal/{updatemanager => updater}/reposign/key.go (100%) rename client/internal/{updatemanager => updater}/reposign/key_test.go (100%) rename client/internal/{updatemanager => updater}/reposign/revocation.go (100%) rename client/internal/{updatemanager => updater}/reposign/revocation_test.go (100%) rename client/internal/{updatemanager 
=> updater}/reposign/root.go (100%) rename client/internal/{updatemanager => updater}/reposign/root_test.go (100%) rename client/internal/{updatemanager => updater}/reposign/signature.go (100%) rename client/internal/{updatemanager => updater}/reposign/signature_test.go (100%) rename client/internal/{updatemanager => updater}/reposign/verify.go (98%) rename client/internal/{updatemanager => updater}/reposign/verify_test.go (100%) create mode 100644 client/internal/updater/supported_darwin.go create mode 100644 client/internal/updater/supported_other.go create mode 100644 client/internal/updater/supported_windows.go rename client/internal/{updatemanager => updater}/update.go (90%) create mode 100644 client/server/triggerupdate.go diff --git a/client/android/client.go b/client/android/client.go index ccf32a90c..3fc571559 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } @@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } diff --git a/client/cmd/signer/artifactkey.go b/client/cmd/signer/artifactkey.go index 5e656650b..ee12326db 100644 --- 
a/client/cmd/signer/artifactkey.go +++ b/client/cmd/signer/artifactkey.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) var ( diff --git a/client/cmd/signer/artifactsign.go b/client/cmd/signer/artifactsign.go index 881be9367..7c02323dc 100644 --- a/client/cmd/signer/artifactsign.go +++ b/client/cmd/signer/artifactsign.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) const ( diff --git a/client/cmd/signer/revocation.go b/client/cmd/signer/revocation.go index 1d84b65c3..5ff636dcb 100644 --- a/client/cmd/signer/revocation.go +++ b/client/cmd/signer/revocation.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) const ( diff --git a/client/cmd/signer/rootkey.go b/client/cmd/signer/rootkey.go index 78ac36b41..eae0da84d 100644 --- a/client/cmd/signer/rootkey.go +++ b/client/cmd/signer/rootkey.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) var ( diff --git a/client/cmd/up.go b/client/cmd/up.go index 9559287d5..f5766522a 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr r := peer.NewRecorder(config.ManagementURL.String()) r.GetFullStatus() - connectClient := internal.NewConnectClient(ctx, config, r, false) + connectClient := internal.NewConnectClient(ctx, config, r) SetupDebugHandler(ctx, config, r, connectClient, "") return connectClient.Run(nil, util.FindFirstLogPath(logFiles)) diff --git 
a/client/cmd/update_supported.go b/client/cmd/update_supported.go index 977875093..0b197f4c5 100644 --- a/client/cmd/update_supported.go +++ b/client/cmd/update_supported.go @@ -11,7 +11,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/util" ) diff --git a/client/embed/embed.go b/client/embed/embed.go index 4fbe0eada..21043cf96 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -202,7 +202,7 @@ func (c *Client) Start(startCtx context.Context) error { if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } - client := internal.NewConnectClient(ctx, c.config, c.recorder, false) + client := internal.NewConnectClient(ctx, c.config, c.recorder) client.SetSyncResponsePersistence(true) // either startup error (permanent backoff err) or nil err (successful engine up) diff --git a/client/internal/connect.go b/client/internal/connect.go index 68a0cb8da..ccd7b6c33 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -27,8 +27,8 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/updatemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater" + "github.com/netbirdio/netbird/client/internal/updater/installer" nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" @@ -44,13 +44,13 @@ import ( ) type ConnectClient struct { - ctx context.Context - config *profilemanager.Config - statusRecorder *peer.Status - doInitialAutoUpdate bool + ctx context.Context + 
config *profilemanager.Config + statusRecorder *peer.Status - engine *Engine - engineMutex sync.Mutex + engine *Engine + engineMutex sync.Mutex + updateManager *updater.Manager persistSyncResponse bool } @@ -59,17 +59,19 @@ func NewConnectClient( ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - doInitalAutoUpdate bool, ) *ConnectClient { return &ConnectClient{ - ctx: ctx, - config: config, - statusRecorder: statusRecorder, - doInitialAutoUpdate: doInitalAutoUpdate, - engineMutex: sync.Mutex{}, + ctx: ctx, + config: config, + statusRecorder: statusRecorder, + engineMutex: sync.Mutex{}, } } +func (c *ConnectClient) SetUpdateManager(um *updater.Manager) { + c.updateManager = um +} + // Run with main logic. func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error { return c.run(MobileDependency{}, runningChan, logPath) @@ -187,14 +189,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan stateManager := statemanager.New(path) stateManager.RegisterState(&sshconfig.ShutdownState{}) - updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager) - if err == nil { - updateManager.CheckUpdateSuccess(c.ctx) + if c.updateManager != nil { + c.updateManager.CheckUpdateSuccess(c.ctx) + } - inst := installer.New() - if err := inst.CleanUpInstallerFiles(); err != nil { - log.Errorf("failed to clean up temporary installer file: %v", err) - } + inst := installer.New() + if err := inst.CleanUpInstallerFiles(); err != nil { + log.Errorf("failed to clean up temporary installer file: %v", err) } defer c.statusRecorder.ClientStop() @@ -308,7 +309,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager) + engine := NewEngine(engineCtx, cancel, engineConfig, 
EngineServices{ + SignalClient: signalClient, + MgmClient: mgmClient, + RelayManager: relayManager, + StatusRecorder: c.statusRecorder, + Checks: checks, + StateManager: stateManager, + UpdateManager: c.updateManager, + }, mobileDependency) engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine = engine c.engineMutex.Unlock() @@ -318,15 +327,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } - if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil { - // AutoUpdate will be true when the user click on "Connect" menu on the UI - if c.doInitialAutoUpdate { - log.Infof("start engine by ui, run auto-update check") - c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate) - c.doInitialAutoUpdate = false - } - } - log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 0f8243e7a..f0f399bef 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -27,7 +27,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" diff --git a/client/internal/engine.go b/client/internal/engine.go index 858202155..fd3bdf7af 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -51,7 +51,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - 
"github.com/netbirdio/netbird/client/internal/updatemanager" + "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/jobexec" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" @@ -79,7 +79,6 @@ const ( var ErrResetConnection = fmt.Errorf("reset connection") -// EngineConfig is a config for the Engine type EngineConfig struct { WgPort int WgIfaceName string @@ -141,6 +140,17 @@ type EngineConfig struct { LogPath string } +// EngineServices holds the external service dependencies required by the Engine. +type EngineServices struct { + SignalClient signal.Client + MgmClient mgm.Client + RelayManager *relayClient.Manager + StatusRecorder *peer.Status + Checks []*mgmProto.Checks + StateManager *statemanager.Manager + UpdateManager *updater.Manager +} + // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. type Engine struct { // signal is a Signal Service client @@ -209,7 +219,7 @@ type Engine struct { flowManager nftypes.FlowManager // auto-update - updateManager *updatemanager.Manager + updateManager *updater.Manager // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor @@ -239,22 +249,17 @@ type localIpUpdater interface { func NewEngine( clientCtx context.Context, clientCancel context.CancelFunc, - signalClient signal.Client, - mgmClient mgm.Client, - relayManager *relayClient.Manager, config *EngineConfig, + services EngineServices, mobileDep MobileDependency, - statusRecorder *peer.Status, - checks []*mgmProto.Checks, - stateManager *statemanager.Manager, ) *Engine { engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, - signal: signalClient, - signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), - mgmClient: mgmClient, - relayManager: relayManager, + signal: services.SignalClient, + signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), + mgmClient: 
services.MgmClient, + relayManager: services.RelayManager, peerStore: peerstore.NewConnStore(), syncMsgMux: &sync.Mutex{}, config: config, @@ -262,11 +267,12 @@ func NewEngine( STUNs: []*stun.URI{}, TURNs: []*stun.URI{}, networkSerial: 0, - statusRecorder: statusRecorder, - stateManager: stateManager, - checks: checks, + statusRecorder: services.StatusRecorder, + stateManager: services.StateManager, + checks: services.Checks, probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), jobExecutor: jobexec.NewExecutor(), + updateManager: services.UpdateManager, } log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) @@ -309,7 +315,7 @@ func (e *Engine) Stop() error { } if e.updateManager != nil { - e.updateManager.Stop() + e.updateManager.SetDownloadOnly() } log.Info("cleaning up status recorder states") @@ -559,13 +565,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return nil } -func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() - - e.handleAutoUpdateVersion(autoUpdateSettings, true) -} - func (e *Engine) createFirewall() error { if e.config.DisableFirewall { log.Infof("firewall is disabled") @@ -793,39 +792,22 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg return nil } -func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) { +func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) { + if e.updateManager == nil { + return + } + if autoUpdateSettings == nil { return } - disabled := autoUpdateSettings.Version == disableAutoUpdate - - // stop and cleanup if disabled - if e.updateManager != nil && disabled { - log.Infof("auto-update is disabled, stopping update manager") - e.updateManager.Stop() - e.updateManager = nil + if autoUpdateSettings.Version == disableAutoUpdate { + log.Infof("auto-update is 
disabled") + e.updateManager.SetDownloadOnly() return } - // Skip check unless AlwaysUpdate is enabled or this is the initial check at startup - if !autoUpdateSettings.AlwaysUpdate && !initialCheck { - log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check") - return - } - - // Start manager if needed - if e.updateManager == nil { - log.Infof("starting auto-update manager") - updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager) - if err != nil { - return - } - e.updateManager = updateManager - e.updateManager.Start(e.ctx) - } - log.Infof("handling auto-update version: %s", autoUpdateSettings.Version) - e.updateManager.SetVersion(autoUpdateSettings.Version) + e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate) } func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { @@ -842,7 +824,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { } if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { - e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false) + e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate) } if update.GetNetbirdConfig() != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 012c8ad6e..f9e7f8fa0 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -251,9 +251,6 @@ func TestEngine_SSH(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine( ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, &EngineConfig{ WgIfaceName: "utun101", WgAddr: "100.64.0.1/24", @@ -263,10 +260,13 @@ func TestEngine_SSH(t *testing.T) { MTU: iface.DefaultMTU, SSHKey: sshKey, }, + EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, 
MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil, - nil, ) engine.dnsServer = &dns.MockServer{ @@ -428,13 +428,18 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun102", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -647,13 +652,18 @@ func TestEngine_Sync(t *testing.T) { return nil } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun103", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{SyncFunc: syncFunc}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -812,13 +822,18 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { wgAddr := fmt.Sprintf("100.66.%d.1/24", n) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, 
&EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { @@ -1014,13 +1029,18 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { wgAddr := fmt.Sprintf("100.66.%d.1/24", n) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + }, EngineServices{ + SignalClient: &signal.MockClient{}, + MgmClient: &mgmt.MockClient{}, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) @@ -1546,7 +1566,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil + e, err := NewEngine(ctx, cancel, conf, EngineServices{ + SignalClient: signalClient, + MgmClient: mgmtClient, + RelayManager: relayMgr, + StatusRecorder: peer.NewRecorder("https://mgm"), + }, MobileDependency{}), nil e.ctx = ctx return e, err } diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go deleted file mode 100644 index 
20ddec10d..000000000 --- a/client/internal/updatemanager/manager_test.go +++ /dev/null @@ -1,214 +0,0 @@ -//go:build windows || darwin - -package updatemanager - -import ( - "context" - "fmt" - "path" - "testing" - "time" - - v "github.com/hashicorp/go-version" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -type versionUpdateMock struct { - latestVersion *v.Version - onUpdate func() -} - -func (v versionUpdateMock) StopWatch() {} - -func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool { - return false -} - -func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) { - v.onUpdate = updateFn -} - -func (v versionUpdateMock) LatestVersion() *v.Version { - return v.latestVersion -} - -func (v versionUpdateMock) StartFetcher() {} - -func Test_LatestVersion(t *testing.T) { - testMatrix := []struct { - name string - daemonVersion string - initialLatestVersion *v.Version - latestVersion *v.Version - shouldUpdateInit bool - shouldUpdateLater bool - }{ - { - name: "Should only trigger update once due to time between triggers being < 5 Minutes", - daemonVersion: "1.0.0", - initialLatestVersion: v.Must(v.NewSemver("1.0.1")), - latestVersion: v.Must(v.NewSemver("1.0.2")), - shouldUpdateInit: true, - shouldUpdateLater: false, - }, - { - name: "Shouldn't update initially, but should update as soon as latest version is fetched", - daemonVersion: "1.0.0", - initialLatestVersion: nil, - latestVersion: v.Must(v.NewSemver("1.0.1")), - shouldUpdateInit: false, - shouldUpdateLater: true, - }, - } - - for idx, c := range testMatrix { - mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} - tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) - m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) - m.update = mockUpdate - - targetVersionChan := make(chan string, 1) - - m.triggerUpdateFn = func(ctx context.Context, targetVersion string) 
error { - targetVersionChan <- targetVersion - return nil - } - m.currentVersion = c.daemonVersion - m.Start(context.Background()) - m.SetVersion("latest") - var triggeredInit bool - select { - case targetVersion := <-targetVersionChan: - if targetVersion != c.initialLatestVersion.String() { - t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion) - } - triggeredInit = true - case <-time.After(10 * time.Millisecond): - triggeredInit = false - } - if triggeredInit != c.shouldUpdateInit { - t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) - } - - mockUpdate.latestVersion = c.latestVersion - mockUpdate.onUpdate() - - var triggeredLater bool - select { - case targetVersion := <-targetVersionChan: - if targetVersion != c.latestVersion.String() { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) - } - triggeredLater = true - case <-time.After(10 * time.Millisecond): - triggeredLater = false - } - if triggeredLater != c.shouldUpdateLater { - t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) - } - - m.Stop() - } -} - -func Test_HandleUpdate(t *testing.T) { - testMatrix := []struct { - name string - daemonVersion string - latestVersion *v.Version - expectedVersion string - shouldUpdate bool - }{ - { - name: "Update to a specific version should update regardless of if latestVersion is available yet", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.56.0", - shouldUpdate: true, - }, - { - name: "Update to specific version should not update if version matches", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "0.55.0", - shouldUpdate: false, - }, - { - name: "Update to specific version should not update if current version is newer", - daemonVersion: "0.55.0", - latestVersion: nil, - 
expectedVersion: "0.54.0", - shouldUpdate: false, - }, - { - name: "Update to latest version should update if latest is newer", - daemonVersion: "0.55.0", - latestVersion: v.Must(v.NewSemver("0.56.0")), - expectedVersion: "latest", - shouldUpdate: true, - }, - { - name: "Update to latest version should not update if latest == current", - daemonVersion: "0.56.0", - latestVersion: v.Must(v.NewSemver("0.56.0")), - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if daemon version is invalid", - daemonVersion: "development", - latestVersion: v.Must(v.NewSemver("1.0.0")), - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if expecting latest and latest version is unavailable", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "latest", - shouldUpdate: false, - }, - { - name: "Should not update if expected version is invalid", - daemonVersion: "0.55.0", - latestVersion: nil, - expectedVersion: "development", - shouldUpdate: false, - }, - } - for idx, c := range testMatrix { - tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) - m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) - m.update = &versionUpdateMock{latestVersion: c.latestVersion} - targetVersionChan := make(chan string, 1) - - m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { - targetVersionChan <- targetVersion - return nil - } - - m.currentVersion = c.daemonVersion - m.Start(context.Background()) - m.SetVersion(c.expectedVersion) - - var updateTriggered bool - select { - case targetVersion := <-targetVersionChan: - if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() { - t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) - } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion { - t.Errorf("%s: Update version mismatch, expected %v, got %v", 
c.name, c.expectedVersion, targetVersion) - } - updateTriggered = true - case <-time.After(10 * time.Millisecond): - updateTriggered = false - } - - if updateTriggered != c.shouldUpdate { - t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) - } - m.Stop() - } -} diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go deleted file mode 100644 index 4e87c2d77..000000000 --- a/client/internal/updatemanager/manager_unsupported.go +++ /dev/null @@ -1,39 +0,0 @@ -//go:build !windows && !darwin - -package updatemanager - -import ( - "context" - "fmt" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -// Manager is a no-op stub for unsupported platforms -type Manager struct{} - -// NewManager returns a no-op manager for unsupported platforms -func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { - return nil, fmt.Errorf("update manager is not supported on this platform") -} - -// CheckUpdateSuccess is a no-op on unsupported platforms -func (m *Manager) CheckUpdateSuccess(ctx context.Context) { - // no-op -} - -// Start is a no-op on unsupported platforms -func (m *Manager) Start(ctx context.Context) { - // no-op -} - -// SetVersion is a no-op on unsupported platforms -func (m *Manager) SetVersion(expectedVersion string) { - // no-op -} - -// Stop is a no-op on unsupported platforms -func (m *Manager) Stop() { - // no-op -} diff --git a/client/internal/updatemanager/doc.go b/client/internal/updater/doc.go similarity index 93% rename from client/internal/updatemanager/doc.go rename to client/internal/updater/doc.go index 54d1bdeab..e1924aa43 100644 --- a/client/internal/updatemanager/doc.go +++ b/client/internal/updater/doc.go @@ -1,4 +1,4 @@ -// Package updatemanager provides automatic update management for the NetBird client. 
+// Package updater provides automatic update management for the NetBird client. // It monitors for new versions, handles update triggers from management server directives, // and orchestrates the download and installation of client updates. // @@ -32,4 +32,4 @@ // // This enables verification of successful updates and appropriate user notification // after the client restarts with the new version. -package updatemanager +package updater diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updater/downloader/downloader.go similarity index 100% rename from client/internal/updatemanager/downloader/downloader.go rename to client/internal/updater/downloader/downloader.go diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updater/downloader/downloader_test.go similarity index 100% rename from client/internal/updatemanager/downloader/downloader_test.go rename to client/internal/updater/downloader/downloader_test.go diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updater/installer/binary_nowindows.go similarity index 100% rename from client/internal/updatemanager/installer/binary_nowindows.go rename to client/internal/updater/installer/binary_nowindows.go diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updater/installer/binary_windows.go similarity index 100% rename from client/internal/updatemanager/installer/binary_windows.go rename to client/internal/updater/installer/binary_windows.go diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updater/installer/doc.go similarity index 100% rename from client/internal/updatemanager/installer/doc.go rename to client/internal/updater/installer/doc.go diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updater/installer/installer.go similarity index 100% rename from client/internal/updatemanager/installer/installer.go 
rename to client/internal/updater/installer/installer.go diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updater/installer/installer_common.go similarity index 97% rename from client/internal/updatemanager/installer/installer_common.go rename to client/internal/updater/installer/installer_common.go index 03378d55f..8e44bee82 100644 --- a/client/internal/updatemanager/installer/installer_common.go +++ b/client/internal/updater/installer/installer_common.go @@ -16,8 +16,8 @@ import ( goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" - "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" + "github.com/netbirdio/netbird/client/internal/updater/downloader" + "github.com/netbirdio/netbird/client/internal/updater/reposign" ) type Installer struct { diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updater/installer/installer_log_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/installer_log_darwin.go rename to client/internal/updater/installer/installer_log_darwin.go diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updater/installer/installer_log_windows.go similarity index 100% rename from client/internal/updatemanager/installer/installer_log_windows.go rename to client/internal/updater/installer/installer_log_windows.go diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updater/installer/installer_run_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/installer_run_darwin.go rename to client/internal/updater/installer/installer_run_darwin.go diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updater/installer/installer_run_windows.go similarity index 100% rename from 
client/internal/updatemanager/installer/installer_run_windows.go rename to client/internal/updater/installer/installer_run_windows.go diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updater/installer/log.go similarity index 100% rename from client/internal/updatemanager/installer/log.go rename to client/internal/updater/installer/log.go diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updater/installer/procattr_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/procattr_darwin.go rename to client/internal/updater/installer/procattr_darwin.go diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updater/installer/procattr_windows.go similarity index 100% rename from client/internal/updatemanager/installer/procattr_windows.go rename to client/internal/updater/installer/procattr_windows.go diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updater/installer/repourl_dev.go similarity index 100% rename from client/internal/updatemanager/installer/repourl_dev.go rename to client/internal/updater/installer/repourl_dev.go diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updater/installer/repourl_prod.go similarity index 100% rename from client/internal/updatemanager/installer/repourl_prod.go rename to client/internal/updater/installer/repourl_prod.go diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updater/installer/result.go similarity index 98% rename from client/internal/updatemanager/installer/result.go rename to client/internal/updater/installer/result.go index 03d08d527..526c3eb53 100644 --- a/client/internal/updatemanager/installer/result.go +++ b/client/internal/updater/installer/result.go @@ -203,7 +203,10 @@ func (rh *ResultHandler) write(result Result) error { func (rh *ResultHandler) cleanup() error { err := 
os.Remove(rh.resultFile) - if err != nil && !os.IsNotExist(err) { + if err != nil { + if os.IsNotExist(err) { + return nil + } return err } log.Debugf("delete installer result file: %s", rh.resultFile) diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updater/installer/types.go similarity index 100% rename from client/internal/updatemanager/installer/types.go rename to client/internal/updater/installer/types.go diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updater/installer/types_darwin.go similarity index 100% rename from client/internal/updatemanager/installer/types_darwin.go rename to client/internal/updater/installer/types_darwin.go diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updater/installer/types_windows.go similarity index 100% rename from client/internal/updatemanager/installer/types_windows.go rename to client/internal/updater/installer/types_windows.go diff --git a/client/internal/updatemanager/manager.go b/client/internal/updater/manager.go similarity index 52% rename from client/internal/updatemanager/manager.go rename to client/internal/updater/manager.go index eae11de56..dfcb93177 100644 --- a/client/internal/updatemanager/manager.go +++ b/client/internal/updater/manager.go @@ -1,12 +1,9 @@ -//go:build windows || darwin - -package updatemanager +package updater import ( "context" "errors" "fmt" - "runtime" "sync" "time" @@ -15,7 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) @@ -41,6 +38,9 @@ type Manager struct { statusRecorder *peer.Status stateManager *statemanager.Manager + downloadOnly bool // true when no enforcement from 
management; notifies UI to download latest + forceUpdate bool // true when management sets AlwaysUpdate; skips UI interaction and installs directly + lastTrigger time.Time mgmUpdateChan chan struct{} updateChannel chan struct{} @@ -53,37 +53,38 @@ type Manager struct { expectedVersion *v.Version updateToLatestVersion bool - // updateMutex protect update and expectedVersion fields + pendingVersion *v.Version + + // updateMutex protects update, expectedVersion, updateToLatestVersion, + // downloadOnly, forceUpdate, pendingVersion, and lastTrigger fields updateMutex sync.Mutex - triggerUpdateFn func(context.Context, string) error + // installMutex and installing guard against concurrent installation attempts + installMutex sync.Mutex + installing bool + + // protect to start the service multiple times + mu sync.Mutex + + autoUpdateSupported func() bool } -func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { - if runtime.GOOS == "darwin" { - isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable() - if isBrew { - log.Warnf("auto-update disabled on Home Brew installation") - return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet") - } - } - return newManager(statusRecorder, stateManager) -} - -func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { +// NewManager creates a new update manager. The manager is single-use: once Stop() is called, it cannot be restarted. 
+func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *Manager { manager := &Manager{ - statusRecorder: statusRecorder, - stateManager: stateManager, - mgmUpdateChan: make(chan struct{}, 1), - updateChannel: make(chan struct{}, 1), - currentVersion: version.NetbirdVersion(), - update: version.NewUpdate("nb/client"), + statusRecorder: statusRecorder, + stateManager: stateManager, + mgmUpdateChan: make(chan struct{}, 1), + updateChannel: make(chan struct{}, 1), + currentVersion: version.NetbirdVersion(), + update: version.NewUpdate("nb/client"), + downloadOnly: true, + autoUpdateSupported: isAutoUpdateSupported, } - manager.triggerUpdateFn = manager.triggerUpdate stateManager.RegisterState(&UpdateState{}) - return manager, nil + return manager } // CheckUpdateSuccess checks if the update was successful and send a notification. @@ -124,8 +125,10 @@ func (m *Manager) CheckUpdateSuccess(ctx context.Context) { } func (m *Manager) Start(ctx context.Context) { + log.Infof("starting update manager") + m.mu.Lock() + defer m.mu.Unlock() if m.cancel != nil { - log.Errorf("Manager already started") return } @@ -142,13 +145,32 @@ func (m *Manager) Start(ctx context.Context) { m.cancel = cancel m.wg.Add(1) - go m.updateLoop(ctx) + go func() { + defer m.wg.Done() + m.updateLoop(ctx) + }() } -func (m *Manager) SetVersion(expectedVersion string) { - log.Infof("set expected agent version for upgrade: %s", expectedVersion) - if m.cancel == nil { - log.Errorf("manager not started") +func (m *Manager) SetDownloadOnly() { + m.updateMutex.Lock() + m.downloadOnly = true + m.forceUpdate = false + m.expectedVersion = nil + m.updateToLatestVersion = false + m.lastTrigger = time.Time{} + m.updateMutex.Unlock() + + select { + case m.mgmUpdateChan <- struct{}{}: + default: + } +} + +func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) { + log.Infof("expected version changed to %s, force update: %t", expectedVersion, forceUpdate) + + if 
!m.autoUpdateSupported() { + log.Warnf("auto-update not supported on this platform") return } @@ -159,6 +181,7 @@ func (m *Manager) SetVersion(expectedVersion string) { log.Errorf("empty expected version provided") m.expectedVersion = nil m.updateToLatestVersion = false + m.downloadOnly = true return } @@ -178,12 +201,97 @@ func (m *Manager) SetVersion(expectedVersion string) { m.updateToLatestVersion = false } + m.lastTrigger = time.Time{} + m.downloadOnly = false + m.forceUpdate = forceUpdate + select { case m.mgmUpdateChan <- struct{}{}: default: } } +// Install triggers the installation of the pending version. It is called when the user clicks the install button in the UI. +func (m *Manager) Install(ctx context.Context) error { + if !m.autoUpdateSupported() { + return fmt.Errorf("auto-update not supported on this platform") + } + + m.updateMutex.Lock() + pending := m.pendingVersion + m.updateMutex.Unlock() + + if pending == nil { + return fmt.Errorf("no pending version to install") + } + + return m.tryInstall(ctx, pending) +} + +// tryInstall ensures only one installation runs at a time. Concurrent callers +// receive an error immediately rather than queuing behind a running install. +func (m *Manager) tryInstall(ctx context.Context, targetVersion *v.Version) error { + m.installMutex.Lock() + if m.installing { + m.installMutex.Unlock() + return fmt.Errorf("installation already in progress") + } + m.installing = true + m.installMutex.Unlock() + + defer func() { + m.installMutex.Lock() + m.installing = false + m.installMutex.Unlock() + }() + + return m.install(ctx, targetVersion) +} + +// NotifyUI re-publishes the current update state to a newly connected UI client. +// Only needed for download-only mode where the latest version is already cached +// NotifyUI re-publishes the current update state so a newly connected UI gets the info. 
+func (m *Manager) NotifyUI() { + m.updateMutex.Lock() + if m.update == nil { + m.updateMutex.Unlock() + return + } + downloadOnly := m.downloadOnly + pendingVersion := m.pendingVersion + latestVersion := m.update.LatestVersion() + m.updateMutex.Unlock() + + if downloadOnly { + if latestVersion == nil { + return + } + currentVersion, err := v.NewVersion(m.currentVersion) + if err != nil || currentVersion.GreaterThanOrEqual(latestVersion) { + return + } + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": latestVersion.String()}, + ) + return + } + + if pendingVersion != nil { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": pendingVersion.String(), "enforced": "true"}, + ) + } +} + +// Stop is not used at the moment because it fully depends on the daemon. In a future refactor it may make sense to use it. 
func (m *Manager) Stop() { if m.cancel == nil { return @@ -214,8 +322,6 @@ func (m *Manager) onContextCancel() { } func (m *Manager) updateLoop(ctx context.Context) { - defer m.wg.Done() - for { select { case <-ctx.Done(): @@ -239,55 +345,89 @@ func (m *Manager) handleUpdate(ctx context.Context) { return } - expectedVersion := m.expectedVersion - useLatest := m.updateToLatestVersion + downloadOnly := m.downloadOnly + forceUpdate := m.forceUpdate curLatestVersion := m.update.LatestVersion() - m.updateMutex.Unlock() switch { - // Resolve "latest" to actual version - case useLatest: + // Download-only mode or resolve "latest" to actual version + case downloadOnly, m.updateToLatestVersion: if curLatestVersion == nil { log.Tracef("latest version not fetched yet") + m.updateMutex.Unlock() return } updateVersion = curLatestVersion - // Update to specific version - case expectedVersion != nil: - updateVersion = expectedVersion + // Install to specific version + case m.expectedVersion != nil: + updateVersion = m.expectedVersion default: log.Debugf("no expected version information set") + m.updateMutex.Unlock() return } log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion) - if !m.shouldUpdate(updateVersion) { + if !m.shouldUpdate(updateVersion, forceUpdate) { + m.updateMutex.Unlock() return } m.lastTrigger = time.Now() - log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion) - m.statusRecorder.PublishEvent( - cProto.SystemEvent_CRITICAL, - cProto.SystemEvent_SYSTEM, - "Automatically updating client", - "Your client version is older than auto-update version set in Management, updating client now.", - nil, - ) + log.Infof("new version available: %s", updateVersion) + + if !downloadOnly && !forceUpdate { + m.pendingVersion = updateVersion + } + m.updateMutex.Unlock() + + if downloadOnly { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + 
cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": updateVersion.String()}, + ) + return + } + + if forceUpdate { + if err := m.tryInstall(ctx, updateVersion); err != nil { + log.Errorf("force update failed: %v", err) + } + return + } + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "New version available", + "", + map[string]string{"new_version_available": updateVersion.String(), "enforced": "true"}, + ) +} + +func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "Updating client", + "Installing update now.", + nil, + ) m.statusRecorder.PublishEvent( cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM, "", "", - map[string]string{"progress_window": "show", "version": updateVersion.String()}, + map[string]string{"progress_window": "show", "version": pendingVersion.String()}, ) updateState := UpdateState{ PreUpdateVersion: m.currentVersion, - TargetVersion: updateVersion.String(), + TargetVersion: pendingVersion.String(), } - if err := m.stateManager.UpdateState(updateState); err != nil { log.Warnf("failed to update state: %v", err) } else { @@ -296,8 +436,9 @@ func (m *Manager) handleUpdate(ctx context.Context) { } } - if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil { - log.Errorf("Error triggering auto-update: %v", err) + inst := installer.New() + if err := inst.RunInstallation(ctx, pendingVersion.String()); err != nil { + log.Errorf("error triggering update: %v", err) m.statusRecorder.PublishEvent( cProto.SystemEvent_ERROR, cProto.SystemEvent_SYSTEM, @@ -305,7 +446,9 @@ func (m *Manager) handleUpdate(ctx context.Context) { fmt.Sprintf("Auto-update failed: %v", err), nil, ) + return err } + return nil } // loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it. 
@@ -339,7 +482,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e return updateState, nil } -func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { +func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool { if m.currentVersion == developmentVersion { log.Debugf("skipping auto-update, running development version") return false @@ -354,8 +497,8 @@ func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { return false } - if time.Since(m.lastTrigger) < 5*time.Minute { - log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) + if forceUpdate && time.Since(m.lastTrigger) < 3*time.Minute { + log.Infof("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) return false } @@ -367,8 +510,3 @@ func (m *Manager) lastResultErrReason() string { result := installer.NewResultHandler(inst.TempDir()) return result.GetErrorResultReason() } - -func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error { - inst := installer.New() - return inst.RunInstallation(ctx, targetVersion) -} diff --git a/client/internal/updater/manager_linux_test.go b/client/internal/updater/manager_linux_test.go new file mode 100644 index 000000000..b05dd7e7d --- /dev/null +++ b/client/internal/updater/manager_linux_test.go @@ -0,0 +1,111 @@ +//go:build !windows && !darwin + +package updater + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// On Linux, only Mode 1 (downloadOnly) is supported. +// SetVersion is a no-op because auto-update installation is not supported. 
+ +func Test_LatestVersion_Linux(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should notify again when a newer version arrives even within 5 minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: true, + }, + { + name: "Shouldn't notify initially, but should notify as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = mockUpdate + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetDownloadOnly() + + ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredInit := ver != "" + if enforced { + t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name) + } + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial notify mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() { + t.Errorf("%s: Initial version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + ver, enforced = waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredLater := ver != "" + if enforced { + t.Errorf("%s: Linux 
Mode 1 must never have enforced metadata", c.name) + } + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Later notify mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Later version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } + + m.Stop() + } +} + +func Test_SetVersion_NoOp_Linux(t *testing.T) { + // On Linux, SetVersion should be a no-op — no events fired + tmpFile := path.Join(t.TempDir(), "update-test-noop.json") + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))} + m.currentVersion = "1.0.0" + m.Start(context.Background()) + m.SetVersion("1.0.1", false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + if ver != "" { + t.Errorf("SetVersion should be a no-op on Linux, but got event with version %s", ver) + } + + m.Stop() +} diff --git a/client/internal/updater/manager_test.go b/client/internal/updater/manager_test.go new file mode 100644 index 000000000..107dca2b3 --- /dev/null +++ b/client/internal/updater/manager_test.go @@ -0,0 +1,227 @@ +//go:build windows || darwin + +package updater + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + cProto "github.com/netbirdio/netbird/client/proto" +) + +func Test_LatestVersion(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should notify again when a newer version arrives even within 5 minutes", + daemonVersion: "1.0.0", 
+ initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: true, + }, + { + name: "Shouldn't update initially, but should update as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = mockUpdate + m.currentVersion = c.daemonVersion + m.autoUpdateSupported = func() bool { return true } + m.Start(context.Background()) + m.SetVersion("latest", false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredInit := ver != "" + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() { + t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + ver, _ = waitForUpdateEvent(sub, 500*time.Millisecond) + triggeredLater := ver != "" + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Later update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Later update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } + + m.Stop() + } +} + +func Test_HandleUpdate(t 
*testing.T) { + testMatrix := []struct { + name string + daemonVersion string + latestVersion *v.Version + expectedVersion string + shouldUpdate bool + }{ + { + name: "Install to a specific version should update regardless of if latestVersion is available yet", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.56.0", + shouldUpdate: true, + }, + { + name: "Install to specific version should not update if version matches", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.55.0", + shouldUpdate: false, + }, + { + name: "Install to specific version should not update if current version is newer", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.54.0", + shouldUpdate: false, + }, + { + name: "Install to latest version should update if latest is newer", + daemonVersion: "0.55.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: true, + }, + { + name: "Install to latest version should not update if latest == current", + daemonVersion: "0.56.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if daemon version is invalid", + daemonVersion: "development", + latestVersion: v.Must(v.NewSemver("1.0.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expecting latest and latest version is unavailable", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expected version is invalid", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "development", + shouldUpdate: false, + }, + } + for idx, c := range testMatrix { + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, 
statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: c.latestVersion} + m.currentVersion = c.daemonVersion + m.autoUpdateSupported = func() bool { return true } + m.Start(context.Background()) + m.SetVersion(c.expectedVersion, false) + + ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond) + updateTriggered := ver != "" + + if updateTriggered { + if c.expectedVersion == "latest" && c.latestVersion != nil && ver != c.latestVersion.String() { + t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver) + } else if c.expectedVersion != "latest" && c.expectedVersion != "development" && ver != c.expectedVersion { + t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.expectedVersion, ver) + } + } + + if updateTriggered != c.shouldUpdate { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) + } + m.Stop() + } +} + +func Test_EnforcedMetadata(t *testing.T) { + // Mode 1 (downloadOnly): no enforced metadata + tmpFile := path.Join(t.TempDir(), "update-test-mode1.json") + recorder := peer.NewRecorder("") + sub := recorder.SubscribeToEvents() + defer recorder.UnsubscribeFromEvents(sub) + + m := NewManager(recorder, statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))} + m.currentVersion = "1.0.0" + m.Start(context.Background()) + m.SetDownloadOnly() + + ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond) + if ver == "" { + t.Fatal("Mode 1: expected new_version_available event") + } + if enforced { + t.Error("Mode 1: expected no enforced metadata") + } + m.Stop() + + // Mode 2 (enforced, forceUpdate=false): enforced metadata present, no auto-install + tmpFile2 := path.Join(t.TempDir(), "update-test-mode2.json") + recorder2 := peer.NewRecorder("") + sub2 := recorder2.SubscribeToEvents() + defer recorder2.UnsubscribeFromEvents(sub2) + + m2 := NewManager(recorder2, statemanager.New(tmpFile2)) + 
m2.update = &versionUpdateMock{latestVersion: nil} + m2.currentVersion = "1.0.0" + m2.autoUpdateSupported = func() bool { return true } + m2.Start(context.Background()) + m2.SetVersion("1.0.1", false) + + ver, enforced2 := waitForUpdateEvent(sub2, 500*time.Millisecond) + if ver == "" { + t.Fatal("Mode 2: expected new_version_available event") + } + if !enforced2 { + t.Error("Mode 2: expected enforced metadata") + } + m2.Stop() +} + +// ensure the proto import is used +var _ = cProto.SystemEvent_INFO diff --git a/client/internal/updater/manager_test_helpers_test.go b/client/internal/updater/manager_test_helpers_test.go new file mode 100644 index 000000000..c7faee1f4 --- /dev/null +++ b/client/internal/updater/manager_test_helpers_test.go @@ -0,0 +1,56 @@ +package updater + +import ( + "strconv" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +type versionUpdateMock struct { + latestVersion *v.Version + onUpdate func() +} + +func (m versionUpdateMock) StopWatch() {} + +func (m versionUpdateMock) SetDaemonVersion(newVersion string) bool { + return false +} + +func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) { + m.onUpdate = updateFn +} + +func (m versionUpdateMock) LatestVersion() *v.Version { + return m.latestVersion +} + +func (m versionUpdateMock) StartFetcher() {} + +// waitForUpdateEvent waits for a new_version_available event, returns the version string or "" on timeout. 
+func waitForUpdateEvent(sub *peer.EventSubscription, timeout time.Duration) (version string, enforced bool) { + timer := time.NewTimer(timeout) + defer timer.Stop() + for { + select { + case event, ok := <-sub.Events(): + if !ok { + return "", false + } + if val, ok := event.Metadata["new_version_available"]; ok { + enforced := false + if raw, ok := event.Metadata["enforced"]; ok { + if parsed, err := strconv.ParseBool(raw); err == nil { + enforced = parsed + } + } + return val, enforced + } + case <-timer.C: + return "", false + } + } +} diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updater/reposign/artifact.go similarity index 100% rename from client/internal/updatemanager/reposign/artifact.go rename to client/internal/updater/reposign/artifact.go diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updater/reposign/artifact_test.go similarity index 100% rename from client/internal/updatemanager/reposign/artifact_test.go rename to client/internal/updater/reposign/artifact_test.go diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updater/reposign/certs/root-pub.pem similarity index 100% rename from client/internal/updatemanager/reposign/certs/root-pub.pem rename to client/internal/updater/reposign/certs/root-pub.pem diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updater/reposign/certsdev/root-pub.pem similarity index 100% rename from client/internal/updatemanager/reposign/certsdev/root-pub.pem rename to client/internal/updater/reposign/certsdev/root-pub.pem diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updater/reposign/doc.go similarity index 100% rename from client/internal/updatemanager/reposign/doc.go rename to client/internal/updater/reposign/doc.go diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updater/reposign/embed_dev.go similarity index 
100% rename from client/internal/updatemanager/reposign/embed_dev.go rename to client/internal/updater/reposign/embed_dev.go diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updater/reposign/embed_prod.go similarity index 100% rename from client/internal/updatemanager/reposign/embed_prod.go rename to client/internal/updater/reposign/embed_prod.go diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updater/reposign/key.go similarity index 100% rename from client/internal/updatemanager/reposign/key.go rename to client/internal/updater/reposign/key.go diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updater/reposign/key_test.go similarity index 100% rename from client/internal/updatemanager/reposign/key_test.go rename to client/internal/updater/reposign/key_test.go diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updater/reposign/revocation.go similarity index 100% rename from client/internal/updatemanager/reposign/revocation.go rename to client/internal/updater/reposign/revocation.go diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updater/reposign/revocation_test.go similarity index 100% rename from client/internal/updatemanager/reposign/revocation_test.go rename to client/internal/updater/reposign/revocation_test.go diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updater/reposign/root.go similarity index 100% rename from client/internal/updatemanager/reposign/root.go rename to client/internal/updater/reposign/root.go diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updater/reposign/root_test.go similarity index 100% rename from client/internal/updatemanager/reposign/root_test.go rename to client/internal/updater/reposign/root_test.go diff --git a/client/internal/updatemanager/reposign/signature.go 
b/client/internal/updater/reposign/signature.go similarity index 100% rename from client/internal/updatemanager/reposign/signature.go rename to client/internal/updater/reposign/signature.go diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updater/reposign/signature_test.go similarity index 100% rename from client/internal/updatemanager/reposign/signature_test.go rename to client/internal/updater/reposign/signature_test.go diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updater/reposign/verify.go similarity index 98% rename from client/internal/updatemanager/reposign/verify.go rename to client/internal/updater/reposign/verify.go index 0af2a8c9e..f64b26a30 100644 --- a/client/internal/updatemanager/reposign/verify.go +++ b/client/internal/updater/reposign/verify.go @@ -10,7 +10,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" + "github.com/netbirdio/netbird/client/internal/updater/downloader" ) const ( diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updater/reposign/verify_test.go similarity index 100% rename from client/internal/updatemanager/reposign/verify_test.go rename to client/internal/updater/reposign/verify_test.go diff --git a/client/internal/updater/supported_darwin.go b/client/internal/updater/supported_darwin.go new file mode 100644 index 000000000..b27754366 --- /dev/null +++ b/client/internal/updater/supported_darwin.go @@ -0,0 +1,22 @@ +package updater + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updater/installer" +) + +func isAutoUpdateSupported() bool { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + isBrew := !installer.TypeOfInstaller(ctx).Downloadable() + if isBrew { + log.Warnf("auto-update disabled on Homebrew installation") + return false + } + 
return true +} diff --git a/client/internal/updater/supported_other.go b/client/internal/updater/supported_other.go new file mode 100644 index 000000000..e09e8c3a3 --- /dev/null +++ b/client/internal/updater/supported_other.go @@ -0,0 +1,7 @@ +//go:build !windows && !darwin + +package updater + +func isAutoUpdateSupported() bool { + return false +} diff --git a/client/internal/updater/supported_windows.go b/client/internal/updater/supported_windows.go new file mode 100644 index 000000000..0c28878c7 --- /dev/null +++ b/client/internal/updater/supported_windows.go @@ -0,0 +1,5 @@ +package updater + +func isAutoUpdateSupported() bool { + return true +} diff --git a/client/internal/updatemanager/update.go b/client/internal/updater/update.go similarity index 90% rename from client/internal/updatemanager/update.go rename to client/internal/updater/update.go index 875b50b49..3056c77e1 100644 --- a/client/internal/updatemanager/update.go +++ b/client/internal/updater/update.go @@ -1,4 +1,4 @@ -package updatemanager +package updater import v "github.com/hashicorp/go-version" diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index aafef41d3..3e2da7f4e 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -160,7 +160,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 3879beba3..fd3c18f56 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.36.6 -// protoc v6.33.3 +// protoc v6.33.1 // source: daemon.proto package proto @@ -945,7 +945,6 @@ type UpRequest struct { state protoimpl.MessageState `protogen:"open.v1"` ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` - AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -994,13 +993,6 @@ func (x *UpRequest) GetUsername() string { return "" } -func (x *UpRequest) GetAutoUpdate() bool { - if x != nil && x.AutoUpdate != nil { - return *x.AutoUpdate - } - return false -} - type UpResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -5032,6 +5024,94 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +type TriggerUpdateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TriggerUpdateRequest) Reset() { + *x = TriggerUpdateRequest{} + mi := &file_daemon_proto_msgTypes[73] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TriggerUpdateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TriggerUpdateRequest) ProtoMessage() {} + +func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[73] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TriggerUpdateRequest.ProtoReflect.Descriptor instead. 
+func (*TriggerUpdateRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{73} +} + +type TriggerUpdateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TriggerUpdateResponse) Reset() { + *x = TriggerUpdateResponse{} + mi := &file_daemon_proto_msgTypes[74] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TriggerUpdateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TriggerUpdateResponse) ProtoMessage() {} + +func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[74] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TriggerUpdateResponse.ProtoReflect.Descriptor instead. 
+func (*TriggerUpdateResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{74} +} + +func (x *TriggerUpdateResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *TriggerUpdateResponse) GetErrorMsg() string { + if x != nil { + return x.ErrorMsg + } + return "" +} + // GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer type GetPeerSSHHostKeyRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -5043,7 +5123,7 @@ type GetPeerSSHHostKeyRequest struct { func (x *GetPeerSSHHostKeyRequest) Reset() { *x = GetPeerSSHHostKeyRequest{} - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5055,7 +5135,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string { func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5068,7 +5148,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. 
func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{73} + return file_daemon_proto_rawDescGZIP(), []int{75} } func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { @@ -5095,7 +5175,7 @@ type GetPeerSSHHostKeyResponse struct { func (x *GetPeerSSHHostKeyResponse) Reset() { *x = GetPeerSSHHostKeyResponse{} - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5107,7 +5187,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string { func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5120,7 +5200,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. 
func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{74} + return file_daemon_proto_rawDescGZIP(), []int{76} } func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { @@ -5162,7 +5242,7 @@ type RequestJWTAuthRequest struct { func (x *RequestJWTAuthRequest) Reset() { *x = RequestJWTAuthRequest{} - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5174,7 +5254,7 @@ func (x *RequestJWTAuthRequest) String() string { func (*RequestJWTAuthRequest) ProtoMessage() {} func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5187,7 +5267,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. 
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{75} + return file_daemon_proto_rawDescGZIP(), []int{77} } func (x *RequestJWTAuthRequest) GetHint() string { @@ -5220,7 +5300,7 @@ type RequestJWTAuthResponse struct { func (x *RequestJWTAuthResponse) Reset() { *x = RequestJWTAuthResponse{} - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5232,7 +5312,7 @@ func (x *RequestJWTAuthResponse) String() string { func (*RequestJWTAuthResponse) ProtoMessage() {} func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5245,7 +5325,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. 
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{76} + return file_daemon_proto_rawDescGZIP(), []int{78} } func (x *RequestJWTAuthResponse) GetVerificationURI() string { @@ -5310,7 +5390,7 @@ type WaitJWTTokenRequest struct { func (x *WaitJWTTokenRequest) Reset() { *x = WaitJWTTokenRequest{} - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[79] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5322,7 +5402,7 @@ func (x *WaitJWTTokenRequest) String() string { func (*WaitJWTTokenRequest) ProtoMessage() {} func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[77] + mi := &file_daemon_proto_msgTypes[79] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5335,7 +5415,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. 
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{77} + return file_daemon_proto_rawDescGZIP(), []int{79} } func (x *WaitJWTTokenRequest) GetDeviceCode() string { @@ -5367,7 +5447,7 @@ type WaitJWTTokenResponse struct { func (x *WaitJWTTokenResponse) Reset() { *x = WaitJWTTokenResponse{} - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5379,7 +5459,7 @@ func (x *WaitJWTTokenResponse) String() string { func (*WaitJWTTokenResponse) ProtoMessage() {} func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5392,7 +5472,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. 
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{78} + return file_daemon_proto_rawDescGZIP(), []int{80} } func (x *WaitJWTTokenResponse) GetToken() string { @@ -5425,7 +5505,7 @@ type StartCPUProfileRequest struct { func (x *StartCPUProfileRequest) Reset() { *x = StartCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[81] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5437,7 +5517,7 @@ func (x *StartCPUProfileRequest) String() string { func (*StartCPUProfileRequest) ProtoMessage() {} func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[81] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5450,7 +5530,7 @@ func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead. 
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{79} + return file_daemon_proto_rawDescGZIP(), []int{81} } // StartCPUProfileResponse confirms CPU profiling has started @@ -5462,7 +5542,7 @@ type StartCPUProfileResponse struct { func (x *StartCPUProfileResponse) Reset() { *x = StartCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5474,7 +5554,7 @@ func (x *StartCPUProfileResponse) String() string { func (*StartCPUProfileResponse) ProtoMessage() {} func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5487,7 +5567,7 @@ func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead. 
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{80} + return file_daemon_proto_rawDescGZIP(), []int{82} } // StopCPUProfileRequest for stopping CPU profiling @@ -5499,7 +5579,7 @@ type StopCPUProfileRequest struct { func (x *StopCPUProfileRequest) Reset() { *x = StopCPUProfileRequest{} - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5511,7 +5591,7 @@ func (x *StopCPUProfileRequest) String() string { func (*StopCPUProfileRequest) ProtoMessage() {} func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[81] + mi := &file_daemon_proto_msgTypes[83] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5524,7 +5604,7 @@ func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead. 
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{81} + return file_daemon_proto_rawDescGZIP(), []int{83} } // StopCPUProfileResponse confirms CPU profiling has stopped @@ -5536,7 +5616,7 @@ type StopCPUProfileResponse struct { func (x *StopCPUProfileResponse) Reset() { *x = StopCPUProfileResponse{} - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5548,7 +5628,7 @@ func (x *StopCPUProfileResponse) String() string { func (*StopCPUProfileResponse) ProtoMessage() {} func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[84] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5561,7 +5641,7 @@ func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead. 
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{82} + return file_daemon_proto_rawDescGZIP(), []int{84} } type InstallerResultRequest struct { @@ -5572,7 +5652,7 @@ type InstallerResultRequest struct { func (x *InstallerResultRequest) Reset() { *x = InstallerResultRequest{} - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[85] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5584,7 +5664,7 @@ func (x *InstallerResultRequest) String() string { func (*InstallerResultRequest) ProtoMessage() {} func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[83] + mi := &file_daemon_proto_msgTypes[85] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5597,7 +5677,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. 
func (*InstallerResultRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{83} + return file_daemon_proto_rawDescGZIP(), []int{85} } type InstallerResultResponse struct { @@ -5610,7 +5690,7 @@ type InstallerResultResponse struct { func (x *InstallerResultResponse) Reset() { *x = InstallerResultResponse{} - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[86] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5622,7 +5702,7 @@ func (x *InstallerResultResponse) String() string { func (*InstallerResultResponse) ProtoMessage() {} func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[84] + mi := &file_daemon_proto_msgTypes[86] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5635,7 +5715,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. 
func (*InstallerResultResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{84} + return file_daemon_proto_rawDescGZIP(), []int{86} } func (x *InstallerResultResponse) GetSuccess() bool { @@ -5667,7 +5747,7 @@ type ExposeServiceRequest struct { func (x *ExposeServiceRequest) Reset() { *x = ExposeServiceRequest{} - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[87] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5679,7 +5759,7 @@ func (x *ExposeServiceRequest) String() string { func (*ExposeServiceRequest) ProtoMessage() {} func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[85] + mi := &file_daemon_proto_msgTypes[87] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5692,7 +5772,7 @@ func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead. 
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{85} + return file_daemon_proto_rawDescGZIP(), []int{87} } func (x *ExposeServiceRequest) GetPort() uint32 { @@ -5756,7 +5836,7 @@ type ExposeServiceEvent struct { func (x *ExposeServiceEvent) Reset() { *x = ExposeServiceEvent{} - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[88] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5768,7 +5848,7 @@ func (x *ExposeServiceEvent) String() string { func (*ExposeServiceEvent) ProtoMessage() {} func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[86] + mi := &file_daemon_proto_msgTypes[88] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5781,7 +5861,7 @@ func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead. 
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{86} + return file_daemon_proto_rawDescGZIP(), []int{88} } func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event { @@ -5821,7 +5901,7 @@ type ExposeServiceReady struct { func (x *ExposeServiceReady) Reset() { *x = ExposeServiceReady{} - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[89] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5833,7 +5913,7 @@ func (x *ExposeServiceReady) String() string { func (*ExposeServiceReady) ProtoMessage() {} func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[87] + mi := &file_daemon_proto_msgTypes[89] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5846,7 +5926,7 @@ func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message { // Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead. 
func (*ExposeServiceReady) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{87} + return file_daemon_proto_rawDescGZIP(), []int{89} } func (x *ExposeServiceReady) GetServiceName() string { @@ -5880,7 +5960,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[91] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5892,7 +5972,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[91] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6016,16 +6096,12 @@ const file_daemon_proto_rawDesc = "" + "\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" + "\bhostname\x18\x02 \x01(\tR\bhostname\",\n" + "\x14WaitSSOLoginResponse\x12\x14\n" + - "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"v\n" + "\tUpRequest\x12%\n" + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" + - "\n" + - "autoUpdate\x18\x03 \x01(\bH\x02R\n" + - "autoUpdate\x88\x01\x01B\x0e\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + - "\t_usernameB\r\n" + - "\v_autoUpdate\"\f\n" + + "\t_usernameJ\x04\b\x03\x10\x04\"\f\n" + "\n" + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + @@ -6380,7 +6456,11 @@ const file_daemon_proto_rawDesc = "" + "\x12GetFeaturesRequest\"x\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" + 
+ "\x14TriggerUpdateRequest\"M\n" + + "\x15TriggerUpdateResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"<\n" + "\x18GetPeerSSHHostKeyRequest\x12 \n" + "\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" + "\x19GetPeerSSHHostKeyResponse\x12\x1e\n" + @@ -6453,7 +6533,7 @@ const file_daemon_proto_rawDesc = "" + "\n" + "EXPOSE_TCP\x10\x02\x12\x0e\n" + "\n" + - "EXPOSE_UDP\x10\x032\xac\x15\n" + + "EXPOSE_UDP\x10\x032\xfc\x15\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6485,7 +6565,8 @@ const file_daemon_proto_rawDesc = "" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" + - "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12N\n" + + "\rTriggerUpdate\x12\x1c.daemon.TriggerUpdateRequest\x1a\x1d.daemon.TriggerUpdateResponse\"\x00\x12Z\n" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" + @@ -6508,7 +6589,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 93) var 
file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (ExposeProtocol)(0), // 1: daemon.ExposeProtocol @@ -6588,34 +6669,36 @@ var file_daemon_proto_goTypes = []any{ (*LogoutResponse)(nil), // 75: daemon.LogoutResponse (*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest (*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse - (*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest - (*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse - (*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest - (*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse - (*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest - (*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse - (*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest - (*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse - (*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest - (*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse - (*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest - (*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse - (*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest - (*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent - (*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady - nil, // 93: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range - nil, // 95: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 96: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp + (*TriggerUpdateRequest)(nil), // 78: daemon.TriggerUpdateRequest + (*TriggerUpdateResponse)(nil), // 79: daemon.TriggerUpdateResponse + (*GetPeerSSHHostKeyRequest)(nil), // 80: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 81: daemon.GetPeerSSHHostKeyResponse + 
(*RequestJWTAuthRequest)(nil), // 82: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 83: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 84: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 85: daemon.WaitJWTTokenResponse + (*StartCPUProfileRequest)(nil), // 86: daemon.StartCPUProfileRequest + (*StartCPUProfileResponse)(nil), // 87: daemon.StartCPUProfileResponse + (*StopCPUProfileRequest)(nil), // 88: daemon.StopCPUProfileRequest + (*StopCPUProfileResponse)(nil), // 89: daemon.StopCPUProfileResponse + (*InstallerResultRequest)(nil), // 90: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 91: daemon.InstallerResultResponse + (*ExposeServiceRequest)(nil), // 92: daemon.ExposeServiceRequest + (*ExposeServiceEvent)(nil), // 93: daemon.ExposeServiceEvent + (*ExposeServiceReady)(nil), // 94: daemon.ExposeServiceReady + nil, // 95: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 96: daemon.PortInfo.Range + nil, // 97: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 98: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 99: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 98, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 99, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 99, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 98, // 5: 
daemon.PeerState.latency:type_name -> google.protobuf.Duration 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState @@ -6626,8 +6709,8 @@ var file_daemon_proto_depIdxs = []int32{ 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 95, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 96, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -6638,13 +6721,13 @@ var file_daemon_proto_depIdxs = []int32{ 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 99, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 97, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 98, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> 
google.protobuf.Duration 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol - 92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady + 94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest @@ -6674,52 +6757,54 @@ var file_daemon_proto_depIdxs = []int32{ 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest - 86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest - 6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest - 9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 17, // 76: 
daemon.DaemonService.Down:output_type -> daemon.DownResponse - 19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 75, // 98: 
daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse - 87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse - 7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent - 72, // [72:108] is the sub-list for method output_type - 36, // [36:72] is the sub-list for method input_type + 78, // 64: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest + 80, // 65: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 82, // 66: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 84, // 67: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 86, // 68: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 88, // 69: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 6, // 70: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 90, // 71: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 92, // 72: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest + 9, // 73: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 11, // 74: daemon.DaemonService.WaitSSOLogin:output_type -> 
daemon.WaitSSOLoginResponse + 13, // 75: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 15, // 76: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 17, // 77: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 19, // 78: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 30, // 79: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 32, // 80: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 32, // 81: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 37, // 82: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 39, // 83: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 41, // 84: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 43, // 85: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 46, // 86: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 48, // 87: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 50, // 88: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 52, // 89: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 56, // 90: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 58, // 91: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 60, // 92: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 62, // 93: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 64, // 94: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 66, // 95: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 68, // 96: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 70, // 97: 
daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 73, // 98: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 75, // 99: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 77, // 100: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 79, // 101: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse + 81, // 102: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 83, // 103: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 85, // 104: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 87, // 105: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 89, // 106: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 7, // 107: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 91, // 108: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 93, // 109: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent + 73, // [73:110] is the sub-list for method output_type + 36, // [36:73] is the sub-list for method input_type 36, // [36:36] is the sub-list for extension type_name 36, // [36:36] is the sub-list for extension extendee 0, // [0:36] is the sub-list for field type_name @@ -6742,8 +6827,8 @@ func file_daemon_proto_init() { file_daemon_proto_msgTypes[56].OneofWrappers = []any{} file_daemon_proto_msgTypes[58].OneofWrappers = []any{} file_daemon_proto_msgTypes[69].OneofWrappers = []any{} - file_daemon_proto_msgTypes[75].OneofWrappers = []any{} - file_daemon_proto_msgTypes[86].OneofWrappers = []any{ + file_daemon_proto_msgTypes[77].OneofWrappers = []any{} + file_daemon_proto_msgTypes[88].OneofWrappers = []any{ (*ExposeServiceEvent_Ready)(nil), } type x struct{} @@ -6752,7 +6837,7 @@ func 
file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 5, - NumMessages: 91, + NumMessages: 93, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 4dc41d401..efafe3af7 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -85,6 +85,10 @@ service DaemonService { rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {} + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + rpc TriggerUpdate(TriggerUpdateRequest) returns (TriggerUpdateResponse) {} + // GetPeerSSHHostKey retrieves SSH host key for a specific peer rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {} @@ -226,7 +230,7 @@ message WaitSSOLoginResponse { message UpRequest { optional string profileName = 1; optional string username = 2; - optional bool autoUpdate = 3; + reserved 3; } message UpResponse {} @@ -725,6 +729,13 @@ message GetFeaturesResponse{ bool disable_update_settings = 2; } +message TriggerUpdateRequest {} + +message TriggerUpdateResponse { + bool success = 1; + string errorMsg = 2; +} + // GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer message GetPeerSSHHostKeyRequest { // peer IP address or FQDN to get SSH host key for diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 4154dce59..e5bd89597 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -64,6 +64,9 @@ type DaemonServiceClient interface { // Logout disconnects from the network and deletes the peer from the management server Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts 
...grpc.CallOption) (*GetFeaturesResponse, error) + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). + TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) // GetPeerSSHHostKey retrieves SSH host key for a specific peer GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) // RequestJWTAuth initiates JWT authentication flow for SSH @@ -363,6 +366,15 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe return out, nil } +func (c *daemonServiceClient) TriggerUpdate(ctx context.Context, in *TriggerUpdateRequest, opts ...grpc.CallOption) (*TriggerUpdateResponse, error) { + out := new(TriggerUpdateResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/TriggerUpdate", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { out := new(GetPeerSSHHostKeyResponse) err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) @@ -508,6 +520,9 @@ type DaemonServiceServer interface { // Logout disconnects from the network and deletes the peer from the management server Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) + // TriggerUpdate initiates installation of the pending enforced version. + // Called when the user clicks the install button in the UI (Mode 2 / enforced update). 
+ TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) // GetPeerSSHHostKey retrieves SSH host key for a specific peer GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) // RequestJWTAuth initiates JWT authentication flow for SSH @@ -613,6 +628,9 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") } +func (UnimplementedDaemonServiceServer) TriggerUpdate(context.Context, *TriggerUpdateRequest) (*TriggerUpdateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method TriggerUpdate not implemented") +} func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") } @@ -1157,6 +1175,24 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_TriggerUpdate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TriggerUpdateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).TriggerUpdate(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/TriggerUpdate", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).TriggerUpdate(ctx, req.(*TriggerUpdateRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec 
func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetPeerSSHHostKeyRequest) if err := dec(in); err != nil { @@ -1419,6 +1455,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeatures", Handler: _DaemonService_GetFeatures_Handler, }, + { + MethodName: "TriggerUpdate", + Handler: _DaemonService_TriggerUpdate_Handler, + }, { MethodName: "GetPeerSSHHostKey", Handler: _DaemonService_GetPeerSSHHostKey_Handler, diff --git a/client/server/event.go b/client/server/event.go index b5c12a3a6..d93151c96 100644 --- a/client/server/event.go +++ b/client/server/event.go @@ -14,6 +14,7 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo }() log.Debug("client subscribed to events") + s.startUpdateManagerForGUI() for { select { diff --git a/client/server/server.go b/client/server/server.go index 69d79d9cd..1d83366ca 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -30,6 +30,8 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) @@ -89,6 +91,8 @@ type Server struct { sleepHandler *sleephandler.SleepHandler + updateManager *updater.Manager + jwtCache *jwtCache } @@ -135,6 +139,12 @@ func (s *Server) Start() error { log.Warnf(errRestoreResidualState, err) } + if s.updateManager == nil { + stateMgr := statemanager.New(s.profileManager.GetStatePath()) + s.updateManager = updater.NewManager(s.statusRecorder, stateMgr) + s.updateManager.CheckUpdateSuccess(s.rootCtx) + } + // if current state contains any error, return it // in all other cases we can continue execution only if status is idle and up command was // not in the progress or already successfully established connection. 
@@ -192,14 +202,14 @@ func (s *Server) Start() error { s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -207,7 +217,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() if s.config.DisableAutoConnect { - if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { log.Debugf("run client connection exited with error: %v", err) } log.Tracef("client connection exited") @@ -236,8 +246,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() runOperation := func() error { - err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan) - doInitialAutoUpdate = false + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. 
Will retry in the background", err) return err @@ -717,11 +726,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - var doAutoUpdate bool - if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate { - doAutoUpdate = true - } - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) s.mutex.Unlock() return s.waitForUp(callerCtx) @@ -1623,9 +1628,10 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) return features, nil } -func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { log.Tracef("running client connection") - client := internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) + client := internal.NewConnectClient(ctx, config, statusRecorder) + client.SetUpdateManager(s.updateManager) client.SetSyncResponsePersistence(s.persistSyncResponse) s.mutex.Lock() @@ -1656,6 +1662,14 @@ func (s *Server) checkUpdateSettingsDisabled() bool { return false } +func (s *Server) startUpdateManagerForGUI() { + if s.updateManager == nil { + return + } + s.updateManager.Start(s.rootCtx) + s.updateManager.NotifyUI() +} + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() diff --git a/client/server/server_connect_test.go b/client/server/server_connect_test.go index 8d31c2ae6..faea7da39 100644 --- a/client/server/server_connect_test.go +++ b/client/server/server_connect_test.go @@ -22,7 +22,7 @@ func newTestServer() *Server { } func 
newDummyConnectClient(ctx context.Context) *internal.ConnectClient { - return internal.NewConnectClient(ctx, nil, nil, false) + return internal.NewConnectClient(ctx, nil, nil) } // TestConnectSetsClientWithMutex validates that connect() sets s.connectClient diff --git a/client/server/server_test.go b/client/server/server_test.go index 82079c531..6de23d501 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -113,7 +113,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } diff --git a/client/server/triggerupdate.go b/client/server/triggerupdate.go new file mode 100644 index 000000000..ffcb527e7 --- /dev/null +++ b/client/server/triggerupdate.go @@ -0,0 +1,24 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/proto" +) + +// TriggerUpdate initiates installation of the pending enforced version. +// It is called when the user clicks the install button in the UI (Mode 2 / enforced update). 
+func (s *Server) TriggerUpdate(ctx context.Context, _ *proto.TriggerUpdateRequest) (*proto.TriggerUpdateResponse, error) { + if s.updateManager == nil { + return &proto.TriggerUpdateResponse{Success: false, ErrorMsg: "update manager not available"}, nil + } + + if err := s.updateManager.Install(ctx); err != nil { + log.Warnf("TriggerUpdate failed: %v", err) + return &proto.TriggerUpdateResponse{Success: false, ErrorMsg: err.Error()}, nil + } + + return &proto.TriggerUpdateResponse{Success: true}, nil +} diff --git a/client/server/updateresult.go b/client/server/updateresult.go index 8e00d5062..8d1ef0e5f 100644 --- a/client/server/updateresult.go +++ b/client/server/updateresult.go @@ -5,7 +5,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/proto" ) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7af00cd20..0574e53d0 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -34,7 +34,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - protobuf "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" @@ -308,10 +307,11 @@ type serviceClient struct { sshJWTCacheTTL int connected bool - update *version.Update daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool + isEnforcedUpdate bool + lastNotifiedVersion string settingsEnabled bool profilesEnabled bool showNetworks bool @@ -367,7 +367,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, - update: version.NewUpdateAndStart("nb/client-ui"), } s.eventHandler = newEventHandler(s) @@ -828,7 +827,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, 
loginResp *proto.Log return nil } -func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { @@ -850,9 +849,7 @@ func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) e return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{ - AutoUpdate: protobuf.Bool(wannaAutoUpdate), - }); err != nil { + if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { return fmt.Errorf("start connection: %w", err) } @@ -933,13 +930,13 @@ func (s *serviceClient) updateStatus() error { systrayIconState = false } - // the updater struct notify by the upgrades available only, but if meanwhile the daemon has successfully - // updated must reset the mUpdate visibility state + // if the daemon version changed (e.g. after a successful update), reset the update indication if s.daemonVersion != status.DaemonVersion { - s.mUpdate.Hide() + if s.daemonVersion != "" { + s.mUpdate.Hide() + s.isUpdateIconActive = false + } s.daemonVersion = status.DaemonVersion - - s.isUpdateIconActive = s.update.SetDaemonVersion(status.DaemonVersion) if !s.isUpdateIconActive { if systrayIconState { systray.SetTemplateIcon(iconConnectedMacOS, s.icConnected) @@ -1091,7 +1088,6 @@ func (s *serviceClient) onTrayReady() { // update exit node menu in case service is already connected go s.updateExitNodes() - s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon @@ -1135,6 +1131,13 @@ func (s *serviceClient) onTrayReady() { } } }) + s.eventManager.AddHandler(func(event *proto.SystemEvent) { + if newVersion, ok := event.Metadata["new_version_available"]; ok { + _, enforced := event.Metadata["enforced"] + 
log.Infof("received new_version_available event: version=%s enforced=%v", newVersion, enforced) + s.onUpdateAvailable(newVersion, enforced) + } + }) go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) @@ -1507,10 +1510,18 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { return &config } -func (s *serviceClient) onUpdateAvailable() { +func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) { s.updateIndicationLock.Lock() defer s.updateIndicationLock.Unlock() + s.isEnforcedUpdate = enforced + if enforced { + s.mUpdate.SetTitle("Install version " + newVersion) + } else { + s.lastNotifiedVersion = "" + s.mUpdate.SetTitle("Download latest version") + } + s.mUpdate.Show() s.isUpdateIconActive = true @@ -1519,6 +1530,11 @@ func (s *serviceClient) onUpdateAvailable() { } else { systray.SetTemplateIcon(iconUpdateDisconnectedMacOS, s.icUpdateDisconnected) } + + if enforced && s.lastNotifiedVersion != newVersion { + s.lastNotifiedVersion = newVersion + s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install")) + } } // onSessionExpire sends a notification to the user when the session expires. 
diff --git a/client/ui/event/event.go b/client/ui/event/event.go index 4d949416d..b8ed09a5c 100644 --- a/client/ui/event/event.go +++ b/client/ui/event/event.go @@ -107,12 +107,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) { handlers := slices.Clone(e.handlers) e.mu.Unlock() - // critical events are always shown - if !enabled && event.Severity != proto.SystemEvent_CRITICAL { - return - } - - if event.UserMessage != "" { + if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) { title := e.getEventTitle(event) body := event.UserMessage id := event.Metadata["id"] diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 6adf8778c..60a580dae 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -82,7 +82,7 @@ func (h *eventHandler) handleConnectClick() { go func() { defer connectCancel() - if err := h.client.menuUpClick(connectCtx, true); err != nil { + if err := h.client.menuUpClick(connectCtx); err != nil { st, ok := status.FromError(err) if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { log.Debugf("connect operation cancelled by user") @@ -211,9 +211,42 @@ func (h *eventHandler) handleGitHubClick() { } func (h *eventHandler) handleUpdateClick() { - if err := openURL(version.DownloadUrl()); err != nil { - log.Errorf("failed to open download URL: %v", err) + h.client.updateIndicationLock.Lock() + enforced := h.client.isEnforcedUpdate + h.client.updateIndicationLock.Unlock() + + if !enforced { + if err := openURL(version.DownloadUrl()); err != nil { + log.Errorf("failed to open download URL: %v", err) + } + return } + + // prevent blocking against a busy server + h.client.mUpdate.Disable() + go func() { + defer h.client.mUpdate.Enable() + conn, err := h.client.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("failed to get service client for update: %v", err) + _ = openURL(version.DownloadUrl()) + return + } + + resp, err := 
conn.TriggerUpdate(h.client.ctx, &proto.TriggerUpdateRequest{}) + if err != nil { + log.Errorf("TriggerUpdate failed: %v", err) + _ = openURL(version.DownloadUrl()) + return + } + if !resp.Success { + log.Errorf("TriggerUpdate failed: %s", resp.ErrorMsg) + _ = openURL(version.DownloadUrl()) + return + } + + log.Infof("update triggered via daemon") + }() } func (h *eventHandler) handleNetworksClick() { diff --git a/client/ui/profile.go b/client/ui/profile.go index a38d8918a..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -397,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func(context.Context, bool) error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -411,7 +411,7 @@ type newProfileMenuArgs struct { profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func(context.Context, bool) error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -579,7 +579,7 @@ func (p *profileMenu) refresh() { connectCtx, connectCancel := context.WithCancel(p.ctx) p.serviceClient.connectCancel = connectCancel - if err := p.upClickCallback(connectCtx, false); err != nil { + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go index 76440d684..bf47ac434 100644 --- a/client/ui/quickactions.go +++ b/client/ui/quickactions.go @@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() { connCmd := connectCommand{ connectClient: func() error { - return s.menuUpClick(s.ctx, false) + return s.menuUpClick(s.ctx) }, } diff --git 
a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index c74fa2660..ef417d3cf 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -107,7 +107,8 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled, AutoUpdate: &proto.AutoUpdateSettings{ - Version: settings.AutoUpdateVersion, + Version: settings.AutoUpdateVersion, + AlwaysUpdate: settings.AutoUpdateAlways, }, } } diff --git a/management/server/account.go b/management/server/account.go index 01d0eebfa..75db36a5f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -335,7 +335,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.DNSDomain != newSettings.DNSDomain || - oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { + oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion || + oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { updateAccountPeers = true } @@ -376,6 +377,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleAutoUpdateAlwaysSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID) if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { 
return nil, err @@ -493,6 +495,16 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con } } +func (am *DefaultAccountManager) handleAutoUpdateAlwaysSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { + if newSettings.AutoUpdateAlways { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysDisabled, nil) + } + } +} + func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { oldEnabled := oldSettings.PeerExposeEnabled newEnabled := newSettings.PeerExposeEnabled diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 948d599ba..ddc3e00c3 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -220,6 +220,11 @@ const ( // AccountPeerExposeDisabled indicates that a user disabled peer expose for the account AccountPeerExposeDisabled Activity = 115 + // AccountAutoUpdateAlwaysEnabled indicates that a user enabled always auto-update for the account + AccountAutoUpdateAlwaysEnabled Activity = 116 + // AccountAutoUpdateAlwaysDisabled indicates that a user disabled always auto-update for the account + AccountAutoUpdateAlwaysDisabled Activity = 117 + // DomainAdded indicates that a user added a custom domain DomainAdded Activity = 118 // DomainDeleted indicates that a user deleted a custom domain @@ -339,6 +344,8 @@ var activityMap = map[Activity]Code{ UserCreated: {"User created", "user.create"}, AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"}, + AccountAutoUpdateAlwaysEnabled: {"Account auto-update always enabled", "account.setting.auto.update.always.enable"}, + 
AccountAutoUpdateAlwaysDisabled: {"Account auto-update always disabled", "account.setting.auto.update.always.disable"}, IdentityProviderCreated: {"Identity provider created", "identityprovider.create"}, IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"}, diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 27a57c434..cc5567e3d 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -225,6 +225,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS return nil, fmt.Errorf("invalid AutoUpdateVersion") } } + if req.Settings.AutoUpdateAlways != nil { + returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways + } return returnSettings, nil } @@ -348,6 +351,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, + AutoUpdateAlways: &settings.AutoUpdateAlways, EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, LocalAuthDisabled: &settings.LocalAuthDisabled, } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 6cbd5908d..739dfe2f6 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -121,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -146,6 +147,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: 
br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -171,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr("latest"), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -196,6 +199,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -221,6 +225,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), @@ -246,6 +251,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateAlways: br(false), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), diff --git a/management/server/types/settings.go b/management/server/types/settings.go index e165968fc..4ea79ec72 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -61,6 +61,10 @@ type Settings struct { // AutoUpdateVersion client auto-update version AutoUpdateVersion string `gorm:"default:'disabled'"` + // AutoUpdateAlways when true, updates are installed automatically in the background; + // when false, updates require user interaction from the UI + AutoUpdateAlways bool `gorm:"default:false"` + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. 
// This is a runtime-only field, not stored in the database. EmbeddedIdpEnabled bool `gorm:"-"` @@ -91,6 +95,7 @@ func (s *Settings) Copy() *Settings { DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, AutoUpdateVersion: s.AutoUpdateVersion, + AutoUpdateAlways: s.AutoUpdateAlways, EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, LocalAuthDisabled: s.LocalAuthDisabled, } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index c67231342..6d2967aa9 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -347,6 +347,10 @@ components: description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") type: string example: "0.51.2" + auto_update_always: + description: When true, updates are installed automatically in the background. When false, updates require user interaction from the UI. + type: boolean + example: false embedded_idp_enabled: description: Indicates whether the embedded identity provider (Dex) is enabled for this account. This is a read-only field. type: boolean diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index f218679c0..f5a2b7ced 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1307,6 +1307,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { + // AutoUpdateAlways When true, updates are installed automatically in the background. When false, updates require user interaction from the UI. + AutoUpdateAlways *bool `json:"auto_update_always,omitempty"` + // AutoUpdateVersion Set Clients auto-update version. 
"latest", "disabled", or a specific version (e.g "0.50.1") AutoUpdateVersion *string `json:"auto_update_version,omitempty"` diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 3667ae27f..fdbe3a365 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -340,8 +340,8 @@ message PeerConfig { message AutoUpdateSettings { string version = 1; /* - alwaysUpdate = true → Updates happen automatically in the background - alwaysUpdate = false → Updates only happen when triggered by a peer connection + alwaysUpdate = true → Updates are installed automatically in the background + alwaysUpdate = false → Updates require user interaction from the UI */ bool alwaysUpdate = 2; } From 3e6baea4052a47f12a3029508ae5806cd87c2faa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 14 Mar 2026 01:36:44 +0800 Subject: [PATCH 21/27] [management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530) --- client/cmd/expose.go | 144 +- client/internal/expose/manager.go | 8 +- client/internal/expose/request.go | 9 +- client/proto/daemon.pb.go | 48 +- client/proto/daemon.proto | 3 + client/server/server.go | 7 +- .../reverseproxy/accesslogs/accesslogentry.go | 42 +- .../modules/reverseproxy/domain/domain.go | 3 + .../reverseproxy/domain/manager/api.go | 9 +- .../reverseproxy/domain/manager/manager.go | 36 +- .../modules/reverseproxy/proxy/manager.go | 1 + .../reverseproxy/proxy/manager/controller.go | 5 + .../reverseproxy/proxy/manager_mock.go | 14 + .../modules/reverseproxy/service/interface.go | 4 +- .../reverseproxy/service/interface_mock.go | 16 +- .../reverseproxy/service/manager/api.go | 9 +- .../service/manager/expose_tracker_test.go | 57 +- .../service/manager/l4_port_test.go | 582 ++++++ .../reverseproxy/service/manager/manager.go | 364 +++- .../service/manager/manager_test.go | 133 +- .../modules/reverseproxy/service/service.go | 580 ++++-- 
.../reverseproxy/service/service_test.go | 191 +- management/internals/server/boot.go | 1 + management/internals/server/modules.go | 4 + .../internals/shared/grpc/expose_service.go | 95 +- management/internals/shared/grpc/proxy.go | 185 +- .../internals/shared/grpc/proxy_test.go | 380 +++- management/server/http/handler.go | 3 +- .../testing/testing_tools/channel/channel.go | 2 + management/server/metrics/selfhosted.go | 6 +- management/server/store/sql_store.go | 39 +- management/server/store/store.go | 4 +- management/server/store/store_mock.go | 38 +- management/server/types/account.go | 37 +- proxy/cmd/proxy/cmd/root.go | 73 +- proxy/handle_mapping_stream_test.go | 11 +- proxy/internal/accesslog/logger.go | 70 +- proxy/internal/accesslog/middleware.go | 7 +- proxy/internal/accesslog/requestip.go | 2 +- proxy/internal/acme/manager.go | 9 +- proxy/internal/acme/manager_test.go | 20 +- proxy/internal/auth/middleware.go | 8 +- proxy/internal/auth/oidc.go | 11 +- proxy/internal/auth/password.go | 12 +- proxy/internal/auth/pin.go | 12 +- proxy/internal/conntrack/conn.go | 8 +- proxy/internal/conntrack/hijacked.go | 85 +- proxy/internal/conntrack/hijacked_test.go | 142 ++ proxy/internal/debug/client.go | 14 +- proxy/internal/debug/handler.go | 68 +- proxy/internal/debug/templates/clients.html | 4 +- proxy/internal/debug/templates/index.html | 6 +- proxy/internal/metrics/l4_metrics_test.go | 69 + proxy/internal/metrics/metrics.go | 244 ++- proxy/internal/netutil/errors.go | 40 + proxy/internal/netutil/errors_test.go | 92 + proxy/internal/proxy/context.go | 19 +- proxy/internal/proxy/proxy_bench_test.go | 6 +- proxy/internal/proxy/reverseproxy.go | 42 +- proxy/internal/proxy/reverseproxy_test.go | 20 +- proxy/internal/proxy/servicemapping.go | 12 +- proxy/internal/proxy/trustedproxy.go | 57 +- proxy/internal/proxy/trustedproxy_test.go | 22 +- proxy/internal/roundtrip/netbird.go | 256 +-- .../internal/roundtrip/netbird_bench_test.go | 33 +- 
proxy/internal/roundtrip/netbird_test.go | 145 +- proxy/internal/tcp/bench_test.go | 133 ++ proxy/internal/tcp/chanlistener.go | 76 + proxy/internal/tcp/peekedconn.go | 39 + proxy/internal/tcp/proxyprotocol.go | 29 + proxy/internal/tcp/proxyprotocol_test.go | 128 ++ proxy/internal/tcp/relay.go | 156 ++ proxy/internal/tcp/relay_test.go | 210 +++ proxy/internal/tcp/router.go | 570 ++++++ proxy/internal/tcp/router_test.go | 1670 +++++++++++++++++ proxy/internal/tcp/snipeek.go | 191 ++ proxy/internal/tcp/snipeek_test.go | 251 +++ proxy/internal/types/types.go | 51 + proxy/internal/types/types_test.go | 54 + proxy/internal/udp/relay.go | 496 +++++ proxy/internal/udp/relay_test.go | 493 +++++ proxy/management_integration_test.go | 10 +- proxy/server.go | 782 +++++++- shared/management/client/grpc.go | 18 +- shared/management/http/api/openapi.yml | 107 +- shared/management/http/api/types.gen.go | 118 +- shared/management/proto/management.pb.go | 53 +- shared/management/proto/management.proto | 3 + shared/management/proto/proxy_service.pb.go | 676 ++++--- shared/management/proto/proxy_service.proto | 16 + 90 files changed, 9611 insertions(+), 1397 deletions(-) create mode 100644 management/internals/modules/reverseproxy/service/manager/l4_port_test.go create mode 100644 proxy/internal/conntrack/hijacked_test.go create mode 100644 proxy/internal/metrics/l4_metrics_test.go create mode 100644 proxy/internal/netutil/errors.go create mode 100644 proxy/internal/netutil/errors_test.go create mode 100644 proxy/internal/tcp/bench_test.go create mode 100644 proxy/internal/tcp/chanlistener.go create mode 100644 proxy/internal/tcp/peekedconn.go create mode 100644 proxy/internal/tcp/proxyprotocol.go create mode 100644 proxy/internal/tcp/proxyprotocol_test.go create mode 100644 proxy/internal/tcp/relay.go create mode 100644 proxy/internal/tcp/relay_test.go create mode 100644 proxy/internal/tcp/router.go create mode 100644 proxy/internal/tcp/router_test.go create mode 100644 
proxy/internal/tcp/snipeek.go create mode 100644 proxy/internal/tcp/snipeek_test.go create mode 100644 proxy/internal/types/types_test.go create mode 100644 proxy/internal/udp/relay.go create mode 100644 proxy/internal/udp/relay_test.go diff --git a/client/cmd/expose.go b/client/cmd/expose.go index 991d3ab86..1334617d8 100644 --- a/client/cmd/expose.go +++ b/client/cmd/expose.go @@ -22,20 +22,24 @@ import ( var pinRegexp = regexp.MustCompile(`^\d{6}$`) var ( - exposePin string - exposePassword string - exposeUserGroups []string - exposeDomain string - exposeNamePrefix string - exposeProtocol string + exposePin string + exposePassword string + exposeUserGroups []string + exposeDomain string + exposeNamePrefix string + exposeProtocol string + exposeExternalPort uint16 ) var exposeCmd = &cobra.Command{ - Use: "expose ", - Short: "Expose a local port via the NetBird reverse proxy", - Args: cobra.ExactArgs(1), - Example: "netbird expose --with-password safe-pass 8080", - RunE: exposeFn, + Use: "expose ", + Short: "Expose a local port via the NetBird reverse proxy", + Args: cobra.ExactArgs(1), + Example: ` netbird expose --with-password safe-pass 8080 + netbird expose --protocol tcp 5432 + netbird expose --protocol tcp --with-external-port 5433 5432 + netbird expose --protocol tls --with-custom-domain tls.example.com 4443`, + RunE: exposeFn, } func init() { @@ -44,7 +48,52 @@ func init() { exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)") exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)") exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. 
--with-name-prefix my-app)") - exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)") + exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)") + exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)") +} + +// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags. +func isClusterProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp", "tls": + return true + default: + return false + } +} + +// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP) +// where domain display doesn't apply. TLS uses SNI so it has a domain. +func isPortBasedProtocol(protocol string) bool { + switch strings.ToLower(protocol) { + case "tcp", "udp": + return true + default: + return false + } +} + +// extractPort returns the port portion of a URL like "tcp://host:12345", or +// falls back to the given default formatted as a string. +func extractPort(serviceURL string, fallback uint16) string { + u := serviceURL + if idx := strings.Index(u, "://"); idx != -1 { + u = u[idx+3:] + } + if i := strings.LastIndex(u, ":"); i != -1 { + if p := u[i+1:]; p != "" { + return p + } + } + return strconv.FormatUint(uint64(fallback), 10) +} + +// resolveExternalPort returns the effective external port, defaulting to the target port. 
+func resolveExternalPort(targetPort uint64) uint16 { + if exposeExternalPort != 0 { + return exposeExternalPort + } + return uint16(targetPort) } func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { @@ -57,7 +106,15 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { } if !isProtocolValid(exposeProtocol) { - return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol) + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) + } + + if isClusterProtocol(exposeProtocol) { + if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 { + return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol) + } + } else if cmd.Flags().Changed("with-external-port") { + return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol) } if exposePin != "" && !pinRegexp.MatchString(exposePin) { @@ -76,7 +133,12 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) { } func isProtocolValid(exposeProtocol string) bool { - return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https" + switch strings.ToLower(exposeProtocol) { + case "http", "https", "tcp", "udp", "tls": + return true + default: + return false + } } func exposeFn(cmd *cobra.Command, args []string) error { @@ -123,7 +185,7 @@ func exposeFn(cmd *cobra.Command, args []string) error { return err } - stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{ + req := &proto.ExposeServiceRequest{ Port: uint32(port), Protocol: protocol, Pin: exposePin, @@ -131,7 +193,12 @@ func exposeFn(cmd *cobra.Command, args []string) error { UserGroups: exposeUserGroups, Domain: exposeDomain, NamePrefix: exposeNamePrefix, - }) + } + if isClusterProtocol(exposeProtocol) { + req.ListenPort = 
uint32(resolveExternalPort(port)) + } + + stream, err := client.ExposeService(ctx, req) if err != nil { return fmt.Errorf("expose service: %w", err) } @@ -149,8 +216,14 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) { return proto.ExposeProtocol_EXPOSE_HTTP, nil case "https": return proto.ExposeProtocol_EXPOSE_HTTPS, nil + case "tcp": + return proto.ExposeProtocol_EXPOSE_TCP, nil + case "udp": + return proto.ExposeProtocol_EXPOSE_UDP, nil + case "tls": + return proto.ExposeProtocol_EXPOSE_TLS, nil default: - return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol) + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) } } @@ -160,20 +233,33 @@ func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServ return fmt.Errorf("receive expose event: %w", err) } - switch e := event.Event.(type) { - case *proto.ExposeServiceEvent_Ready: - cmd.Println("Service exposed successfully!") - cmd.Printf(" Name: %s\n", e.Ready.ServiceName) - cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl) - cmd.Printf(" Domain: %s\n", e.Ready.Domain) - cmd.Printf(" Protocol: %s\n", exposeProtocol) - cmd.Printf(" Port: %d\n", port) - cmd.Println() - cmd.Println("Press Ctrl+C to stop exposing.") - return nil - default: + ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready) + if !ok { return fmt.Errorf("unexpected expose event: %T", event.Event) } + printExposeReady(cmd, ready.Ready, port) + return nil +} + +func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) { + cmd.Println("Service exposed successfully!") + cmd.Printf(" Name: %s\n", r.ServiceName) + if r.ServiceUrl != "" { + cmd.Printf(" URL: %s\n", r.ServiceUrl) + } + if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) { + cmd.Printf(" Domain: %s\n", r.Domain) + } + cmd.Printf(" Protocol: %s\n", exposeProtocol) + cmd.Printf(" Internal: %d\n", port) + if 
isClusterProtocol(exposeProtocol) { + cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port))) + } + if r.PortAutoAssigned && exposeExternalPort != 0 { + cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort) + } + cmd.Println() + cmd.Println("Press Ctrl+C to stop exposing.") } func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error { diff --git a/client/internal/expose/manager.go b/client/internal/expose/manager.go index 8cd93685e..c59a1a7bd 100644 --- a/client/internal/expose/manager.go +++ b/client/internal/expose/manager.go @@ -12,9 +12,10 @@ const renewTimeout = 10 * time.Second // Response holds the response from exposing a service. type Response struct { - ServiceName string - ServiceURL string - Domain string + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool } type Request struct { @@ -25,6 +26,7 @@ type Request struct { Pin string Password string UserGroups []string + ListenPort uint16 } type ManagementClient interface { diff --git a/client/internal/expose/request.go b/client/internal/expose/request.go index 7e12d0513..bff4f2ce7 100644 --- a/client/internal/expose/request.go +++ b/client/internal/expose/request.go @@ -15,6 +15,7 @@ func NewRequest(req *daemonProto.ExposeServiceRequest) *Request { UserGroups: req.UserGroups, Domain: req.Domain, NamePrefix: req.NamePrefix, + ListenPort: uint16(req.ListenPort), } } @@ -27,13 +28,15 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest { Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, + ListenPort: req.ListenPort, } } func fromClientExposeResponse(response *mgm.ExposeResponse) *Response { return &Response{ - ServiceName: response.ServiceName, - Domain: response.Domain, - ServiceURL: response.ServiceURL, + ServiceName: response.ServiceName, + Domain: response.Domain, + ServiceURL: response.ServiceURL, + PortAutoAssigned: 
response.PortAutoAssigned, } } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index fd3c18f56..fa0b2f93b 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -95,6 +95,7 @@ const ( ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 ) // Enum value maps for ExposeProtocol. @@ -104,12 +105,14 @@ var ( 1: "EXPOSE_HTTPS", 2: "EXPOSE_TCP", 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", } ExposeProtocol_value = map[string]int32{ "EXPOSE_HTTP": 0, "EXPOSE_HTTPS": 1, "EXPOSE_TCP": 2, "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, } ) @@ -5741,6 +5744,7 @@ type ExposeServiceRequest struct { UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -5824,6 +5828,13 @@ func (x *ExposeServiceRequest) GetNamePrefix() string { return "" } +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + type ExposeServiceEvent struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Event: @@ -5891,12 +5902,13 @@ type ExposeServiceEvent_Ready struct { func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {} type ExposeServiceReady struct { - state protoimpl.MessageState `protogen:"open.v1"` - ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` - ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" 
json:"service_url,omitempty"` - Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExposeServiceReady) Reset() { @@ -5950,6 +5962,13 @@ func (x *ExposeServiceReady) GetDomain() string { return "" } +func (x *ExposeServiceReady) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -6499,7 +6518,7 @@ const file_daemon_proto_rawDesc = "" + "\x16InstallerResultRequest\"O\n" + "\x17InstallerResultResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + - "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\x87\x02\n" + "\x14ExposeServiceRequest\x12\x12\n" + "\x04port\x18\x01 \x01(\rR\x04port\x122\n" + "\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" + @@ -6509,15 +6528,18 @@ const file_daemon_proto_rawDesc = "" + "userGroups\x12\x16\n" + "\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" + "\vname_prefix\x18\a \x01(\tR\n" + - "namePrefix\"Q\n" + + "namePrefix\x12\x1f\n" + + "\vlisten_port\x18\b \x01(\rR\n" + + "listenPort\"Q\n" + "\x12ExposeServiceEvent\x122\n" + "\x05ready\x18\x01 
\x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" + - "\x05event\"p\n" + + "\x05event\"\x9e\x01\n" + "\x12ExposeServiceReady\x12!\n" + "\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" + "\vservice_url\x18\x02 \x01(\tR\n" + "serviceUrl\x12\x16\n" + - "\x06domain\x18\x03 \x01(\tR\x06domain*b\n" + + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6526,14 +6548,16 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a*S\n" + + "\x05TRACE\x10\a*c\n" + "\x0eExposeProtocol\x12\x0f\n" + "\vEXPOSE_HTTP\x10\x00\x12\x10\n" + "\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" + "\n" + "EXPOSE_TCP\x10\x02\x12\x0e\n" + "\n" + - "EXPOSE_UDP\x10\x032\xfc\x15\n" + + "EXPOSE_UDP\x10\x03\x12\x0e\n" + + "\n" + + "EXPOSE_TLS\x10\x042\xfc\x15\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index efafe3af7..89302c8c3 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -821,6 +821,7 @@ enum ExposeProtocol { EXPOSE_HTTPS = 1; EXPOSE_TCP = 2; EXPOSE_UDP = 3; + EXPOSE_TLS = 4; } message ExposeServiceRequest { @@ -831,6 +832,7 @@ message ExposeServiceRequest { repeated string user_groups = 5; string domain = 6; string name_prefix = 7; + uint32 listen_port = 8; } message ExposeServiceEvent { @@ -843,4 +845,5 @@ message ExposeServiceReady { string service_name = 1; string service_url = 2; string domain = 3; + bool port_auto_assigned = 4; } diff --git a/client/server/server.go b/client/server/server.go index 1d83366ca..7c1e70692 100644 --- a/client/server/server.go +++ 
b/client/server/server.go @@ -1378,9 +1378,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon if err := srv.Send(&proto.ExposeServiceEvent{ Event: &proto.ExposeServiceEvent_Ready{ Ready: &proto.ExposeServiceReady{ - ServiceName: result.ServiceName, - ServiceUrl: result.ServiceURL, - Domain: result.Domain, + ServiceName: result.ServiceName, + ServiceUrl: result.ServiceURL, + Domain: result.Domain, + PortAutoAssigned: result.PortAutoAssigned, }, }, }); err != nil { diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index 0bcc59b68..619a34684 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -10,6 +10,15 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// AccessLogProtocol identifies the transport protocol of an access log entry. 
+type AccessLogProtocol string + +const ( + AccessLogProtocolHTTP AccessLogProtocol = "http" + AccessLogProtocolTCP AccessLogProtocol = "tcp" + AccessLogProtocolUDP AccessLogProtocol = "udp" +) + type AccessLogEntry struct { ID string `gorm:"primaryKey"` AccountID string `gorm:"index"` @@ -22,10 +31,11 @@ type AccessLogEntry struct { Duration time.Duration `gorm:"index"` StatusCode int `gorm:"index"` Reason string - UserId string `gorm:"index"` - AuthMethodUsed string `gorm:"index"` - BytesUpload int64 `gorm:"index"` - BytesDownload int64 `gorm:"index"` + UserId string `gorm:"index"` + AuthMethodUsed string `gorm:"index"` + BytesUpload int64 `gorm:"index"` + BytesDownload int64 `gorm:"index"` + Protocol AccessLogProtocol `gorm:"index"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -43,17 +53,22 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) { a.AccountID = serviceLog.GetAccountId() a.BytesUpload = serviceLog.GetBytesUpload() a.BytesDownload = serviceLog.GetBytesDownload() + a.Protocol = AccessLogProtocol(serviceLog.GetProtocol()) if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { - if ip, err := netip.ParseAddr(sourceIP); err == nil { - a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice()) + if addr, err := netip.ParseAddr(sourceIP); err == nil { + addr = addr.Unmap() + a.GeoLocation.ConnectionIP = net.IP(addr.AsSlice()) } } - if !serviceLog.GetAuthSuccess() { - a.Reason = "Authentication failed" - } else if serviceLog.GetResponseCode() >= 400 { - a.Reason = "Request failed" + // Only set reason for HTTP entries. L4 entries have no auth or status code. 
+ if a.Protocol == "" || a.Protocol == AccessLogProtocolHTTP { + if !serviceLog.GetAuthSuccess() { + a.Reason = "Authentication failed" + } else if serviceLog.GetResponseCode() >= 400 { + a.Reason = "Request failed" + } } } @@ -90,6 +105,12 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { cityName = &a.GeoLocation.CityName } + var protocol *string + if a.Protocol != "" { + p := string(a.Protocol) + protocol = &p + } + return &api.ProxyAccessLog{ Id: a.ID, ServiceId: a.ServiceID, @@ -107,5 +128,6 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { CityName: cityName, BytesUpload: a.BytesUpload, BytesDownload: a.BytesDownload, + Protocol: protocol, } } diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index 83fd669af..861d026a7 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -14,6 +14,9 @@ type Domain struct { TargetCluster string // The proxy cluster this domain should be validated against Type Type `gorm:"-"` Validated bool + // SupportsCustomPorts is populated at query time for free domains from the + // proxy cluster capabilities. Not persisted. 
+ SupportsCustomPorts *bool `gorm:"-"` } // EventMeta returns activity event metadata for a domain diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 2fbcdd5b8..d26a6a418 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -42,10 +42,11 @@ func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { func domainToApi(d *domain.Domain) api.ReverseProxyDomain { resp := api.ReverseProxyDomain{ - Domain: d.Domain, - Id: d.ID, - Type: domainTypeToApi(d.Type), - Validated: d.Validated, + Domain: d.Domain, + Id: d.ID, + Type: domainTypeToApi(d.Type), + Validated: d.Validated, + SupportsCustomPorts: d.SupportsCustomPorts, } if d.TargetCluster != "" { resp.TargetCluster = &d.TargetCluster diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 8bbc98726..813027ea2 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -33,11 +33,16 @@ type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) } +type clusterCapabilities interface { + ClusterSupportsCustomPorts(clusterAddr string) *bool +} + type Manager struct { - store store - validator domain.Validator - proxyManager proxyManager - permissionsManager permissions.Manager + store store + validator domain.Validator + proxyManager proxyManager + clusterCapabilities clusterCapabilities + permissionsManager permissions.Manager accountManager account.Manager } @@ -51,6 +56,11 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio } } +// SetClusterCapabilities sets the cluster capabilities provider for domain queries. 
+func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) { + m.clusterCapabilities = caps +} + func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { @@ -80,24 +90,32 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d }).Debug("getting domains with proxy allow list") for _, cluster := range allowList { - ret = append(ret, &domain.Domain{ + d := &domain.Domain{ Domain: cluster, AccountID: accountID, Type: domain.TypeFree, Validated: true, - }) + } + if m.clusterCapabilities != nil { + d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster) + } + ret = append(ret, d) } // Add custom domains. for _, d := range domains { - ret = append(ret, &domain.Domain{ + cd := &domain.Domain{ ID: d.ID, Domain: d.Domain, AccountID: accountID, TargetCluster: d.TargetCluster, Type: domain.TypeCustom, Validated: d.Validated, - }) + } + if m.clusterCapabilities != nil && d.TargetCluster != "" { + cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster) + } + ret = append(ret, cd) } return ret, nil @@ -298,7 +316,7 @@ func extractClusterFromCustomDomains(domain string, customDomains []*domain.Doma // It matches the domain suffix against available clusters and returns the matching cluster. 
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) { for _, cluster := range availableClusters { - if strings.HasSuffix(domain, "."+cluster) { + if domain == cluster || strings.HasSuffix(domain, "."+cluster) { return cluster, true } } diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 15f2f9f54..67a8e74fa 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -33,4 +33,5 @@ type Controller interface { RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error GetProxiesForCluster(clusterAddr string) []string + ClusterSupportsCustomPorts(clusterAddr string) *bool } diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go index e5b3e9886..acb49c45b 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/controller.go +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -72,6 +72,11 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster return nil } +// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports. +func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool { + return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr) +} + // GetProxiesForCluster returns all proxy IDs registered for a specific cluster. 
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { proxySet, ok := c.clusterProxies.Load(clusterAddr) diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index d9645ba88..b07a21122 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -144,6 +144,20 @@ func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) } +// ClusterSupportsCustomPorts mocks base method. +func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. +func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr) +} + // GetProxiesForCluster mocks base method. 
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index b420f22a8..39fd7e3ae 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -22,7 +22,7 @@ type Manager interface { GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) - RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error - StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error + RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error + StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StartExposeReaper(ctx context.Context) } diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index 727b2c7de..bdc1f3e65 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -211,17 +211,17 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter } // RenewServiceFromPeer mocks base method. 
-func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer. -func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID) } // SetCertificateIssuedAt mocks base method. @@ -265,17 +265,17 @@ func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Ca } // StopServiceFromPeer mocks base method. -func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // StopServiceFromPeer indicates an expected call of StopServiceFromPeer. 
-func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID) } // UpdateService mocks base method. diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index f28b633b8..c53219d2e 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -11,19 +11,22 @@ import ( domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) type handler struct { - manager rpservice.Manager + manager rpservice.Manager + permissionsManager permissions.Manager } // RegisterEndpoints registers all service HTTP endpoints. 
-func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { +func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) { h := &handler{ - manager: manager, + manager: manager, + permissionsManager: permissionsManager, } domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go index c831b4a22..6ff8343b9 100644 --- a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) @@ -28,8 +28,8 @@ func TestReapExpiredExposes(t *testing.T) { // Create a non-expired service resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8081, - Protocol: "http", + Port: 8081, + Mode: "http", }) require.NoError(t, err) @@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) expireEphemeralService(t, testStore, testAccountID, resp.Domain) // Delete the service before reaping - err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + 
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) // Reaping should handle the already-deleted service gracefully @@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) { for i := range 5 { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080 + i, - Protocol: "http", + Port: uint16(8080 + i), + Mode: "http", }) require.NoError(t, err) } @@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) { t.Run("renew succeeds for active service", func(t *testing.T) { resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8082, - Protocol: "http", + Port: 8082, + Mode: "http", }) require.NoError(t, err) - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, lookupErr) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) require.NoError(t, err) }) t.Run("renew fails for nonexistent domain", func(t *testing.T) { - err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") require.Error(t, err) assert.Contains(t, err.Error(), "no active expose session") }) @@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) { assert.Equal(t, int64(0), count) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8083, - Protocol: "http", + Port: 8083, + Mode: "http", }) require.NoError(t, err) @@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) { for i := range maxExposesPerPeer { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8090 + i, - Protocol: "http", + Port: uint16(8090 + i), + Mode: "http", }) require.NoError(t, 
err, "expose %d should succeed", i) } _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 9999, - Protocol: "http", + Port: 9999, + Mode: "http", }) require.Error(t, err) assert.Contains(t, err.Error(), "maximum number of active expose sessions") @@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) { ctx := context.Background() resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8086, - Protocol: "http", + Port: 8086, + Mode: "http", }) require.NoError(t, err) @@ -185,7 +188,9 @@ func TestReapSkipsRenewedService(t *testing.T) { expireEphemeralService(t, testStore, testAccountID, resp.Domain) // Renew it before the reaper runs - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svc, err := testStore.GetServiceByDomain(ctx, resp.Domain) + require.NoError(t, err) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID) require.NoError(t, err) // Reaper should skip it because the re-check sees a fresh timestamp @@ -195,6 +200,14 @@ func TestReapSkipsRenewedService(t *testing.T) { require.NoError(t, err, "renewed service should survive reaping") } +// resolveServiceIDByDomain looks up a service ID by domain in tests. +func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string { + t.Helper() + svc, err := s.GetServiceByDomain(context.Background(), domain) + require.NoError(t, err) + return svc.ID +} + // expireEphemeralService backdates meta_last_renewed_at to force expiration. 
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { t.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go new file mode 100644 index 000000000..c7a61ddcf --- /dev/null +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -0,0 +1,582 @@ +package manager + +import ( + "context" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/mock_server" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +const testCluster = "test-cluster" + +func boolPtr(v bool) *bool { return &v } + +// setupL4Test creates a manager with a mock proxy controller for L4 port tests. 
+func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) { + t.Helper() + + ctrl := gomock.NewController(t) + + ctx := context.Background() + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + err = testStore.SaveAccount(ctx, &types.Account{ + Id: testAccountID, + CreatedBy: testUserID, + Settings: &types.Settings{ + PeerExposeEnabled: true, + PeerExposeGroups: []string{testGroupID}, + }, + Users: map[string]*types.User{ + testUserID: { + Id: testUserID, + AccountID: testAccountID, + Role: types.UserRoleAdmin, + }, + }, + Peers: map[string]*nbpeer.Peer{ + testPeerID: { + ID: testPeerID, + AccountID: testAccountID, + Key: "test-key", + DNSLabel: "test-peer", + Name: "test-peer", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, + }, + Groups: map[string]*types.Group{ + testGroupID: { + ID: testGroupID, + AccountID: testAccountID, + Name: "Expose Group", + }, + }, + }) + require.NoError(t, err) + + err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) + require.NoError(t, err) + + mockCtrl := proxy.NewMockController(ctrl) + mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes() + mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + + accountMgr := &mock_server.MockAccountManager{ + StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { + return testStore.GetGroupByName(ctx, store.LockingStrengthNone, 
groupName, accountID) + }, + } + + mgr := &Manager{ + store: testStore, + accountManager: accountMgr, + permissionsManager: permissions.NewManager(testStore), + proxyController: mockCtrl, + clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, + } + mgr.exposeReaper = &exposeReaper{manager: mgr} + + return mgr, testStore, mockCtrl +} + +// seedService creates a service directly in the store for test setup. +func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service { + t.Helper() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: name, + Mode: protocol, + Domain: domain, + ProxyCluster: cluster, + ListenPort: port, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + err := s.CreateService(context.Background(), svc) + require.NoError(t, err) + return svc +} + +func TestPortConflict_TCPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-tcp", + Mode: "tcp", + Domain: "conflicting-tcp." 
+ testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP+TCP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_UDPSamePortCluster(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "conflicting-udp", + Mode: "udp", + Domain: "conflicting-udp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 5432, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "UDP+UDP on same port/cluster should be rejected") + assert.Contains(t, err.Error(), "already in use") +} + +func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tls", + Mode: "tls", + Domain: "app2.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := 
mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)") +} + +func TestPortConflict_TLSSamePortSameDomain(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "duplicate-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TLS+TLS on same domain should be rejected") + assert.Contains(t, err.Error(), "domain already taken") +} + +func TestPortConflict_TLSAndTCPSamePort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "new-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 443, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)") +} + +func TestAutoAssign_TCPNoListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + 
Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax) + assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set") +} + +func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it") + assert.Contains(t, err.Error(), "custom ports") +} + +func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := 
mgr.persistNewService(ctx, testAccountID, svc) + assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability") + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS") +} + +func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, err) + assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden") + assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax, + "auto-assigned port %d should be in range", svc.ListenPort) + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "ephemeral-tls", + Mode: "tls", + Domain: "app.example.com", + ProxyCluster: testCluster, + ListenPort: 9999, + Enabled: true, + Source: "ephemeral", + SourcePeer: testPeerID, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc) + require.NoError(t, 
err) + assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden") + assert.False(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_AvoidsExistingPorts(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existingPort := uint16(20000) + seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort) + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "auto-tcp", + Mode: "tcp", + Domain: "auto-tcp." + testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing") + assert.True(t, svc.PortAutoAssigned) +} + +func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + svc := &rpservice.Service{ + AccountID: testAccountID, + Name: "custom-tcp", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 5555, + Enabled: true, + Source: "permanent", + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true}, + }, + } + svc.InitNewRecord() + + err := mgr.persistNewService(ctx, testAccountID, svc) + require.NoError(t, err) + assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported") + assert.False(t, svc.PortAutoAssigned) +} + +func TestUpdate_PreservesExistingListenPort(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := 
seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc-renamed", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 0, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0") +} + +func TestUpdate_AllowsPortChange(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345) + + updated := &rpservice.Service{ + ID: existing.ID, + AccountID: testAccountID, + Name: "tcp-svc", + Mode: "tcp", + Domain: testCluster, + ProxyCluster: testCluster, + ListenPort: 54321, + Enabled: true, + Targets: []*rpservice.Target{ + {AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true}, + }, + } + + _, err := mgr.persistServiceUpdate(ctx, testAccountID, updated) + require.NoError(t, err) + assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied") +} + +func TestCreateServiceFromPeer_TCP(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + assert.NotEmpty(t, resp.ServiceName) + assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain") + assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support 
custom ports") + assert.Contains(t, resp.ServiceURL, "tcp://") +} + +func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + ListenPort: 15432, + }) + require.NoError(t, err) + + assert.False(t, resp.PortAutoAssigned) + assert.Contains(t, resp.ServiceURL, ":15432") +} + +func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 5432, + Mode: "tcp", + }) + require.NoError(t, err) + + // When no explicit listen port, defaults to target port + assert.Contains(t, resp.ServiceURL, ":5432") + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TLS(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(false)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 443, + Mode: "tls", + }) + require.NoError(t, err) + + assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain") + assert.Contains(t, resp.ServiceURL, "tls://") + assert.Contains(t, resp.ServiceURL, ":443") + // TLS always keeps its port (not port-based protocol for auto-assign) + assert.False(t, resp.PortAutoAssigned) +} + +func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) { + mgr, testStore, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + }) + require.NoError(t, err) + + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) 
+ require.NoError(t, err) + + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.NoError(t, err) + + // Renew after stop should fail + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) + require.Error(t, err) +} + +func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) { + mgr, _, _ := setupL4Test(t, boolPtr(true)) + ctx := context.Background() + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8080, + Mode: "tcp", + Pin: "123456", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index cae3d3bda..c40961fdc 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "math/rand/v2" + "os" "slices" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -23,6 +25,45 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +const ( + defaultAutoAssignPortMin uint16 = 10000 + defaultAutoAssignPortMax uint16 = 49151 + + // EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports. + EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN" + // EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports. 
+ EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX" +) + +var ( + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax +) + +func init() { + autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin) + autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax) + if autoAssignPortMin > autoAssignPortMax { + log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults", + EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax) + autoAssignPortMin = defaultAutoAssignPortMin + autoAssignPortMax = defaultAutoAssignPortMax + } +} + +func portFromEnv(key string, fallback uint16) uint16 { + val := os.Getenv(key) + if val == "" { + return fallback + } + n, err := strconv.ParseUint(val, 10, 16) + if err != nil { + log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err) + return fallback + } + return uint16(n) +} + const unknownHostPlaceholder = "unknown" // ClusterDeriver derives the proxy cluster from a domain. 
@@ -115,6 +156,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s * return fmt.Errorf("unknown target type: %s", target.TargetType) } } + return nil } @@ -197,55 +239,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri return nil } -func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { +func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil { + if svc.Domain != "" { + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + } + + if err := m.ensureL4Port(ctx, transaction, svc); err != nil { return err } - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } - - if err := transaction.CreateService(ctx, service); err != nil { - return fmt.Errorf("failed to create service: %w", err) - } - - return nil - }) -} - -// persistNewEphemeralService creates an ephemeral service inside a single transaction -// that also enforces the duplicate and per-peer limit checks atomically. -// The count and exists queries use FOR UPDATE locking to serialize concurrent creates -// for the same peer, preventing the per-peer limit from being bypassed. -func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - // Lock the peer row to serialize concurrent creates for the same peer. - // Without this, when no ephemeral rows exist yet, FOR UPDATE on the services - // table returns no rows and acquires no locks, allowing concurrent inserts - // to bypass the per-peer limit. 
- if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { - return fmt.Errorf("lock peer row: %w", err) - } - - exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) - if err != nil { - return fmt.Errorf("check existing expose: %w", err) - } - if exists { - return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") - } - - count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) - if err != nil { - return fmt.Errorf("count peer exposes: %w", err) - } - if count >= int64(maxExposesPerPeer) { - return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) - } - - if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { return err } @@ -261,11 +267,155 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee }) } +// ensureL4Port auto-assigns a listen port when needed and validates cluster support. 
+func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error { + if !service.IsL4Protocol(svc.Mode) { + return nil + } + customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster) + if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) { + if svc.Source != service.SourceEphemeral { + return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster) + } + svc.ListenPort = 0 + } + if svc.ListenPort == 0 { + port, err := m.assignPort(ctx, tx, svc.ProxyCluster) + if err != nil { + return err + } + svc.ListenPort = port + svc.PortAutoAssigned = true + } + return nil +} + +// checkPortConflict rejects L4 services that would conflict on the same listener. +// For TCP/UDP: unique per cluster+protocol+port. +// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports). +// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked: +// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener. 
+func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error { + if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 { + return nil + } + + existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort) + if err != nil { + return fmt.Errorf("query port conflicts: %w", err) + } + for _, s := range existing { + if s.ID == svc.ID { + continue + } + // TLS services on the same port are allowed if they have different domains (SNI routing) + if svc.Mode == service.ModeTLS && s.Domain != svc.Domain { + continue + } + return status.Errorf(status.AlreadyExists, + "%s port %d is already in use by service %q on cluster %s", + svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster) + } + + return nil +} + +// assignPort picks a random available port on the cluster within the auto-assign range. +func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) { + services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster) + if err != nil { + return 0, fmt.Errorf("query cluster ports: %w", err) + } + + occupied := make(map[uint16]struct{}, len(services)) + for _, s := range services { + if s.ListenPort > 0 { + occupied[s.ListenPort] = struct{}{} + } + } + + portRange := int(autoAssignPortMax-autoAssignPortMin) + 1 + for range 100 { + port := autoAssignPortMin + uint16(rand.IntN(portRange)) + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + for port := autoAssignPortMin; port <= autoAssignPortMax; port++ { + if _, taken := occupied[port]; !taken { + return port, nil + } + } + + return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster) +} + +// persistNewEphemeralService creates an ephemeral service inside a single transaction +// that also enforces the duplicate and per-peer limit checks atomically. 
+// The count and exists queries use FOR UPDATE locking to serialize concurrent creates +// for the same peer, preventing the per-peer limit from being bypassed. +func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { + return err + } + + if err := m.ensureL4Port(ctx, transaction, svc); err != nil { + return err + } + + if err := m.checkPortConflict(ctx, transaction, svc); err != nil { + return err + } + + if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { + return err + } + + if err := transaction.CreateService(ctx, svc); err != nil { + return fmt.Errorf("create service: %w", err) + } + + return nil + }) +} + +func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error { + // Lock the peer row to serialize concurrent creates for the same peer. 
+ if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { + return fmt.Errorf("lock peer row: %w", err) + } + + exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) + if err != nil { + return fmt.Errorf("check existing expose: %w", err) + } + if exists { + return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") + } + + if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { + return err + } + + count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return fmt.Errorf("count peer exposes: %w", err) + } + if count >= int64(maxExposesPerPeer) { + return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) + } + + return nil +} + +// checkDomainAvailable checks that no other service already uses this domain. 
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error { existingService, err := transaction.GetServiceByDomain(ctx, domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { - return fmt.Errorf("failed to check existing service: %w", err) + return fmt.Errorf("check existing service: %w", err) } return nil } @@ -322,6 +472,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se return err } + if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { + return err + } + updateInfo.oldCluster = existingService.ProxyCluster updateInfo.domainChanged = existingService.Domain != service.Domain @@ -335,12 +489,18 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se m.preserveExistingAuthSecrets(service, existingService) m.preserveServiceMetadata(service, existingService) + m.preserveListenPort(service, existingService) updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled + if err := m.ensureL4Port(ctx, transaction, service); err != nil { + return err + } + if err := m.checkPortConflict(ctx, transaction, service); err != nil { + return err + } if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { return err } - if err := transaction.UpdateService(ctx, service); err != nil { return fmt.Errorf("update service: %w", err) } @@ -351,23 +511,39 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se return &updateInfo, err } -func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { - if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil { +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { + if err 
:= m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { return err } if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) + newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain) + log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) } else { - service.ProxyCluster = newCluster + svc.ProxyCluster = newCluster } } return nil } +// validateProtocolChange rejects mode changes on update. +// Only empty<->HTTP is allowed; all other transitions are rejected. +func validateProtocolChange(oldMode, newMode string) error { + if newMode == "" || newMode == oldMode { + return nil + } + if isHTTPFamily(oldMode) && isHTTPFamily(newMode) { + return nil + } + return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode) +} + +func isHTTPFamily(mode string) bool { + return mode == "" || mode == "http" +} + func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) { if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && @@ -388,6 +564,13 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv service.SessionPublicKey = existingService.SessionPublicKey } +func (m *Manager) preserveListenPort(svc, existing *service.Service) { + if existing.ListenPort > 0 && svc.ListenPort == 0 { + svc.ListenPort = existing.ListenPort + svc.PortAutoAssigned = existing.PortAutoAssigned + } +} + func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) { oidcCfg := m.proxyController.GetOIDCValidationConfig() @@ -675,6 +858,10 @@ func (m *Manager) 
validateExposePermission(ctx context.Context, accountID, peerI return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group") } +func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) { + return m.buildRandomDomain(serviceName) +} + // CreateServiceFromPeer creates a service initiated by a peer expose request. // It validates the request, checks expose permissions, enforces the per-peer limit, // creates the service, and tracks it for TTL-based reaping. @@ -696,9 +883,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s svc.Source = service.SourceEphemeral if svc.Domain == "" { - domain, err := m.buildRandomDomain(svc.Name) + domain, err := m.resolveDefaultDomain(svc.Name) if err != nil { - return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err) + return nil, err } svc.Domain = domain } @@ -739,10 +926,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) + serviceURL := "https://" + svc.Domain + if service.IsL4Protocol(svc.Mode) { + serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort) + } + return &service.ExposeServiceResponse{ - ServiceName: svc.Name, - ServiceURL: "https://" + svc.Domain, - Domain: svc.Domain, + ServiceName: svc.Name, + ServiceURL: serviceURL, + Domain: svc.Domain, + PortAutoAssigned: svc.PortAutoAssigned, }, nil } @@ -761,64 +954,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr return groupIDs, nil } -func (m *Manager) buildRandomDomain(name string) (string, error) { +func (m *Manager) getDefaultClusterDomain() (string, error) { if m.clusterDeriver == nil { - return "", fmt.Errorf("unable to get random domain") + return "", 
fmt.Errorf("unable to get cluster domain") } clusterDomains := m.clusterDeriver.GetClusterDomains() if len(clusterDomains) == 0 { - return "", fmt.Errorf("no cluster domains found for service %s", name) + return "", fmt.Errorf("no cluster domains available") } - index := rand.IntN(len(clusterDomains)) - domain := name + "." + clusterDomains[index] - return domain, nil + return clusterDomains[rand.IntN(len(clusterDomains))], nil +} + +func (m *Manager) buildRandomDomain(name string) (string, error) { + domain, err := m.getDefaultClusterDomain() + if err != nil { + return "", err + } + return name + "." + domain, nil } // RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service. -func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { - return m.store.RenewEphemeralService(ctx, accountID, peerID, domain) +func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID) } // StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB. -func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { - if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil { - log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err) +func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err) return err } return nil } -// deleteServiceFromPeer deletes a peer-initiated service identified by domain. +// deleteServiceFromPeer deletes a peer-initiated service identified by service ID. 
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed. -func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error { - svc, err := m.lookupPeerService(ctx, accountID, peerID, domain) - if err != nil { - return err - } - +func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error { activityCode := activity.PeerServiceUnexposed if expired { activityCode = activity.PeerServiceExposeExpired } - return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode) -} - -// lookupPeerService finds a peer-initiated service by domain and validates ownership. -func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) { - svc, err := m.store.GetServiceByDomain(ctx, domain) - if err != nil { - return nil, err - } - - if svc.Source != service.SourceEphemeral { - return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose") - } - - if svc.SourcePeer != peerID { - return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer") - } - - return svc, nil + return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode) } func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index ba4e1c805..d23c91017 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -803,8 +803,8 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 
8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -826,9 +826,9 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, _ := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 80, - Protocol: "http", - Domain: "example.com", + Port: 80, + Mode: "http", + Domain: "example.com", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -847,8 +847,8 @@ func TestCreateServiceFromPeer(t *testing.T) { require.NoError(t, err) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -860,8 +860,8 @@ func TestCreateServiceFromPeer(t *testing.T) { mgr, _ := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 0, - Protocol: "http", + Port: 0, + Mode: "http", } _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) @@ -878,62 +878,52 @@ func TestExposeServiceRequestValidate(t *testing.T) { }{ { name: "valid http request", - req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"}, wantErr: "", }, { - name: "valid https request with pin", - req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"}, - wantErr: "", + name: "https mode rejected", + req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"}, + wantErr: "unsupported mode", }, { name: "port zero rejected", - req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"}, wantErr: "port must be between 1 and 65535", }, { - name: "negative port rejected", - req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"}, - wantErr: "port must be between 1 and 65535", - }, - { - name: "port above 65535 rejected", - req: rpservice.ExposeServiceRequest{Port: 
65536, Protocol: "http"}, - wantErr: "port must be between 1 and 65535", - }, - { - name: "unsupported protocol", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"}, - wantErr: "unsupported protocol", + name: "unsupported mode", + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"}, + wantErr: "unsupported mode", }, { name: "invalid pin format", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"}, wantErr: "invalid pin", }, { name: "pin too short", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"}, wantErr: "invalid pin", }, { name: "valid 6-digit pin", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"}, wantErr: "", }, { name: "empty user group name", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}}, wantErr: "user group name cannot be empty", }, { name: "invalid name prefix", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"}, wantErr: "invalid name prefix", }, { name: "valid name prefix", - req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"}, + req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"}, wantErr: "", }, } @@ -966,14 +956,14 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { // First create a service req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, 
req) require.NoError(t, err) - // Delete by domain using unexported method - err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false) require.NoError(t, err) // Verify service is deleted @@ -982,16 +972,17 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { }) t.Run("expire uses correct activity", func(t *testing.T) { - mgr, _ := setupIntegrationTest(t) + mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true) require.NoError(t, err) }) } @@ -1003,13 +994,14 @@ func TestStopServiceFromPeer(t *testing.T) { mgr, testStore := setupIntegrationTest(t) req := &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) _, err = testStore.GetServiceByDomain(ctx, resp.Domain) @@ -1022,8 +1014,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { mgr, testStore := setupIntegrationTest(t) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) @@ -1042,8 +1034,8 @@ func 
TestDeleteService_DeletesEphemeralExpose(t *testing.T) { assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete") _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 9090, - Protocol: "http", + Port: 9090, + Mode: "http", }) assert.NoError(t, err, "new expose should succeed after API delete") } @@ -1054,8 +1046,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) { for i := range 3 { _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080 + i, - Protocol: "http", + Port: uint16(8080 + i), + Mode: "http", }) require.NoError(t, err) } @@ -1076,21 +1068,22 @@ func TestRenewServiceFromPeer(t *testing.T) { ctx := context.Background() t.Run("renews tracked expose", func(t *testing.T) { - mgr, _ := setupIntegrationTest(t) + mgr, testStore := setupIntegrationTest(t) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", }) require.NoError(t, err) - err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + svcID := resolveServiceIDByDomain(t, testStore, resp.Domain) + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID) require.NoError(t, err) }) t.Run("fails for untracked domain", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id") require.Error(t, err) }) } @@ -1191,3 +1184,33 @@ func TestDeleteService_DeletesTargets(t *testing.T) { require.NoError(t, err) assert.Len(t, targets, 0, "All targets should be deleted when service is deleted") } + +func TestValidateProtocolChange(t *testing.T) { + tests := []struct { + name string + oldP string + newP string + wantErr bool + }{ + {"empty to 
http", "", "http", false}, + {"http to http", "http", "http", false}, + {"same protocol", "tcp", "tcp", false}, + {"empty new proto", "tcp", "", false}, + {"http to tcp", "http", "tcp", true}, + {"tcp to udp", "tcp", "udp", true}, + {"tls to http", "tls", "http", true}, + {"udp to tls", "udp", "tls", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateProtocolChange(tt.oldP, tt.newP) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot change mode") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index bfad7fe9a..623284404 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -34,6 +34,7 @@ const ( ) type Status string +type TargetType string const ( StatusPending Status = "pending" @@ -43,34 +44,36 @@ const ( StatusCertificateFailed Status = "certificate_failed" StatusError Status = "error" - TargetTypePeer = "peer" - TargetTypeHost = "host" - TargetTypeDomain = "domain" - TargetTypeSubnet = "subnet" + TargetTypePeer TargetType = "peer" + TargetTypeHost TargetType = "host" + TargetTypeDomain TargetType = "domain" + TargetTypeSubnet TargetType = "subnet" SourcePermanent = "permanent" SourceEphemeral = "ephemeral" ) type TargetOptions struct { - SkipTLSVerify bool `json:"skip_tls_verify"` - RequestTimeout time.Duration `json:"request_timeout,omitempty"` - PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` - CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` + SkipTLSVerify bool `json:"skip_tls_verify"` + RequestTimeout time.Duration `json:"request_timeout,omitempty"` + SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"` + PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` + CustomHeaders 
map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` } type Target struct { - ID uint `gorm:"primaryKey" json:"-"` - AccountID string `gorm:"index:idx_target_account;not null" json:"-"` - ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` - Path *string `json:"path,omitempty"` - Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored - Port int `gorm:"index:idx_target_port" json:"port"` - Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` - TargetId string `gorm:"index:idx_target_id" json:"target_id"` - TargetType string `gorm:"index:idx_target_type" json:"target_type"` - Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` - Options TargetOptions `gorm:"embedded" json:"options"` + ID uint `gorm:"primaryKey" json:"-"` + AccountID string `gorm:"index:idx_target_account;not null" json:"-"` + ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` + Path *string `json:"path,omitempty"` + Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored + Port uint16 `gorm:"index:idx_target_port" json:"port"` + Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` + TargetId string `gorm:"index:idx_target_id" json:"target_id"` + TargetType TargetType `gorm:"index:idx_target_type" json:"target_type"` + Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` + Options TargetOptions `gorm:"embedded" json:"options"` + ProxyProtocol bool `json:"proxy_protocol"` } type PasswordAuthConfig struct { @@ -146,23 +149,10 @@ type Service struct { SessionPublicKey string `gorm:"column:session_public_key"` Source string `gorm:"default:'permanent';index:idx_service_source_peer"` SourcePeer string `gorm:"index:idx_service_source_peer"` -} - -func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service { - for _, target := range targets { - target.AccountID = accountID - } - - s 
:= &Service{ - AccountID: accountID, - Name: name, - Domain: domain, - ProxyCluster: proxyCluster, - Targets: targets, - Enabled: enabled, - } - s.InitNewRecord() - return s + // Mode determines the service type: "http", "tcp", "udp", or "tls". + Mode string `gorm:"default:'http'"` + ListenPort uint16 + PortAutoAssigned bool } // InitNewRecord generates a new unique ID and resets metadata for a newly created @@ -177,21 +167,17 @@ func (s *Service) InitNewRecord() { } func (s *Service) ToAPIResponse() *api.Service { - s.Auth.ClearSecrets() - authConfig := api.ServiceAuthConfig{} if s.Auth.PasswordAuth != nil { authConfig.PasswordAuth = &api.PasswordAuthConfig{ - Enabled: s.Auth.PasswordAuth.Enabled, - Password: s.Auth.PasswordAuth.Password, + Enabled: s.Auth.PasswordAuth.Enabled, } } if s.Auth.PinAuth != nil { authConfig.PinAuth = &api.PINAuthConfig{ Enabled: s.Auth.PinAuth.Enabled, - Pin: s.Auth.PinAuth.Pin, } } @@ -208,13 +194,18 @@ func (s *Service) ToAPIResponse() *api.Service { st := api.ServiceTarget{ Path: target.Path, Host: &target.Host, - Port: target.Port, + Port: int(target.Port), Protocol: api.ServiceTargetProtocol(target.Protocol), TargetId: target.TargetId, TargetType: api.ServiceTargetTargetType(target.TargetType), Enabled: target.Enabled, } - st.Options = targetOptionsToAPI(target.Options) + opts := targetOptionsToAPI(target.Options) + if opts == nil { + opts = &api.ServiceTargetOptions{} + } + opts.ProxyProtocol = &target.ProxyProtocol + st.Options = opts apiTargets = append(apiTargets, st) } @@ -227,6 +218,9 @@ func (s *Service) ToAPIResponse() *api.Service { meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt } + mode := api.ServiceMode(s.Mode) + listenPort := int(s.ListenPort) + resp := &api.Service{ Id: s.ID, Name: s.Name, @@ -237,6 +231,9 @@ func (s *Service) ToAPIResponse() *api.Service { RewriteRedirects: &s.RewriteRedirects, Auth: authConfig, Meta: meta, + Mode: &mode, + ListenPort: &listenPort, + PortAutoAssigned: &s.PortAutoAssigned, } 
if s.ProxyCluster != "" { @@ -247,37 +244,7 @@ func (s *Service) ToAPIResponse() *api.Service { } func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping { - pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) - for _, target := range s.Targets { - if !target.Enabled { - continue - } - - // TODO: Make path prefix stripping configurable per-target. - // Currently the matching prefix is baked into the target URL path, - // so the proxy strips-then-re-adds it (effectively a no-op). - targetURL := url.URL{ - Scheme: target.Protocol, - Host: target.Host, - Path: "/", // TODO: support service path - } - if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { - targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port)) - } - - path := "/" - if target.Path != nil { - path = *target.Path - } - - pm := &proto.PathMapping{ - Path: path, - Target: targetURL.String(), - } - - pm.Options = targetOptionsToProto(target.Options) - pathMappings = append(pathMappings, pm) - } + pathMappings := s.buildPathMappings() auth := &proto.Authentication{ SessionKey: s.SessionPublicKey, @@ -306,9 +273,58 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf AccountId: s.AccountID, PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, + Mode: s.Mode, + ListenPort: int32(s.ListenPort), //nolint:gosec } } +// buildPathMappings constructs PathMapping entries from targets. +// For HTTP/HTTPS, each target becomes a path-based route with a full URL. +// For L4/TLS, a single target maps to a host:port address. 
+func (s *Service) buildPathMappings() []*proto.PathMapping { + pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) + for _, target := range s.Targets { + if !target.Enabled { + continue + } + + if IsL4Protocol(s.Mode) { + pm := &proto.PathMapping{ + Target: net.JoinHostPort(target.Host, strconv.FormatUint(uint64(target.Port), 10)), + } + opts := l4TargetOptionsToProto(target) + if opts != nil { + pm.Options = opts + } + pathMappings = append(pathMappings, pm) + continue + } + + // HTTP/HTTPS: build full URL + targetURL := url.URL{ + Scheme: target.Protocol, + Host: target.Host, + Path: "/", + } + if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { + targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10)) + } + + path := "/" + if target.Path != nil { + path = *target.Path + } + + pm := &proto.PathMapping{ + Path: path, + Target: targetURL.String(), + } + pm.Options = targetOptionsToProto(target.Options) + pathMappings = append(pathMappings, pm) + } + return pathMappings +} + func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { switch op { case Create: @@ -325,8 +341,8 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { // isDefaultPort reports whether port is the standard default for the given scheme // (443 for https, 80 for http). -func isDefaultPort(scheme string, port int) bool { - return (scheme == "https" && port == 443) || (scheme == "http" && port == 80) +func isDefaultPort(scheme string, port uint16) bool { + return (scheme == TargetProtoHTTPS && port == 443) || (scheme == TargetProtoHTTP && port == 80) } // PathRewriteMode controls how the request path is rewritten before forwarding. 
@@ -346,7 +362,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode { } func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { - if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { + if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { return nil } apiOpts := &api.ServiceTargetOptions{} @@ -357,6 +373,10 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { s := opts.RequestTimeout.String() apiOpts.RequestTimeout = &s } + if opts.SessionIdleTimeout != 0 { + s := opts.SessionIdleTimeout.String() + apiOpts.SessionIdleTimeout = &s + } if opts.PathRewrite != "" { pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite) apiOpts.PathRewrite = &pr @@ -382,6 +402,23 @@ func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions { return popts } +// l4TargetOptionsToProto converts L4-relevant target options to proto. 
+func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions { + if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 { + return nil + } + opts := &proto.PathTargetOptions{ + ProxyProtocol: target.ProxyProtocol, + } + if target.Options.RequestTimeout > 0 { + opts.RequestTimeout = durationpb.New(target.Options.RequestTimeout) + } + if target.Options.SessionIdleTimeout > 0 { + opts.SessionIdleTimeout = durationpb.New(target.Options.SessionIdleTimeout) + } + return opts +} + func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) { var opts TargetOptions if o.SkipTlsVerify != nil { @@ -394,6 +431,13 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, } opts.RequestTimeout = d } + if o.SessionIdleTimeout != nil { + d, err := time.ParseDuration(*o.SessionIdleTimeout) + if err != nil { + return opts, fmt.Errorf("target %d: parse session_idle_timeout %q: %w", idx, *o.SessionIdleTimeout, err) + } + opts.SessionIdleTimeout = d + } if o.PathRewrite != nil { opts.PathRewrite = PathRewriteMode(*o.PathRewrite) } @@ -408,15 +452,49 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro s.Domain = req.Domain s.AccountID = accountID - targets := make([]*Target, 0, len(req.Targets)) - for i, apiTarget := range req.Targets { + if req.Mode != nil { + s.Mode = string(*req.Mode) + } + if req.ListenPort != nil { + s.ListenPort = uint16(*req.ListenPort) //nolint:gosec + } + + targets, err := targetsFromAPI(accountID, req.Targets) + if err != nil { + return err + } + s.Targets = targets + s.Enabled = req.Enabled + + if req.PassHostHeader != nil { + s.PassHostHeader = *req.PassHostHeader + } + if req.RewriteRedirects != nil { + s.RewriteRedirects = *req.RewriteRedirects + } + + if req.Auth != nil { + s.Auth = authFromAPI(req.Auth) + } + + return nil +} + +func targetsFromAPI(accountID string, apiTargetsPtr *[]api.ServiceTarget) ([]*Target, 
error) { + var apiTargets []api.ServiceTarget + if apiTargetsPtr != nil { + apiTargets = *apiTargetsPtr + } + + targets := make([]*Target, 0, len(apiTargets)) + for i, apiTarget := range apiTargets { target := &Target{ AccountID: accountID, Path: apiTarget.Path, - Port: apiTarget.Port, + Port: uint16(apiTarget.Port), //nolint:gosec // validated by API layer Protocol: string(apiTarget.Protocol), TargetId: apiTarget.TargetId, - TargetType: string(apiTarget.TargetType), + TargetType: TargetType(apiTarget.TargetType), Enabled: apiTarget.Enabled, } if apiTarget.Host != nil { @@ -425,49 +503,42 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro if apiTarget.Options != nil { opts, err := targetOptionsFromAPI(i, apiTarget.Options) if err != nil { - return err + return nil, err } target.Options = opts + if apiTarget.Options.ProxyProtocol != nil { + target.ProxyProtocol = *apiTarget.Options.ProxyProtocol + } } targets = append(targets, target) } - s.Targets = targets + return targets, nil +} - s.Enabled = req.Enabled - - if req.PassHostHeader != nil { - s.PassHostHeader = *req.PassHostHeader - } - - if req.RewriteRedirects != nil { - s.RewriteRedirects = *req.RewriteRedirects - } - - if req.Auth.PasswordAuth != nil { - s.Auth.PasswordAuth = &PasswordAuthConfig{ - Enabled: req.Auth.PasswordAuth.Enabled, - Password: req.Auth.PasswordAuth.Password, +func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig { + var auth AuthConfig + if reqAuth.PasswordAuth != nil { + auth.PasswordAuth = &PasswordAuthConfig{ + Enabled: reqAuth.PasswordAuth.Enabled, + Password: reqAuth.PasswordAuth.Password, } } - - if req.Auth.PinAuth != nil { - s.Auth.PinAuth = &PINAuthConfig{ - Enabled: req.Auth.PinAuth.Enabled, - Pin: req.Auth.PinAuth.Pin, + if reqAuth.PinAuth != nil { + auth.PinAuth = &PINAuthConfig{ + Enabled: reqAuth.PinAuth.Enabled, + Pin: reqAuth.PinAuth.Pin, } } - - if req.Auth.BearerAuth != nil { + if reqAuth.BearerAuth != nil { bearerAuth := 
&BearerAuthConfig{ - Enabled: req.Auth.BearerAuth.Enabled, + Enabled: reqAuth.BearerAuth.Enabled, } - if req.Auth.BearerAuth.DistributionGroups != nil { - bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups + if reqAuth.BearerAuth.DistributionGroups != nil { + bearerAuth.DistributionGroups = *reqAuth.BearerAuth.DistributionGroups } - s.Auth.BearerAuth = bearerAuth + auth.BearerAuth = bearerAuth } - - return nil + return auth } func (s *Service) Validate() error { @@ -478,14 +549,69 @@ func (s *Service) Validate() error { return errors.New("service name exceeds maximum length of 255 characters") } - if s.Domain == "" { - return errors.New("service domain is required") - } - if len(s.Targets) == 0 { return errors.New("at least one target is required") } + if s.Mode == "" { + s.Mode = ModeHTTP + } + + switch s.Mode { + case ModeHTTP: + return s.validateHTTPMode() + case ModeTCP, ModeUDP: + return s.validateTCPUDPMode() + case ModeTLS: + return s.validateTLSMode() + default: + return fmt.Errorf("unsupported mode %q", s.Mode) + } +} + +func (s *Service) validateHTTPMode() error { + if s.Domain == "" { + return errors.New("service domain is required") + } + if s.ListenPort != 0 { + return errors.New("listen_port is not supported for HTTP services") + } + return s.validateHTTPTargets() +} + +func (s *Service) validateTCPUDPMode() error { + if s.Domain == "" { + return errors.New("domain is required for TCP/UDP services (used for cluster derivation)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TCP/UDP services") + } + if len(s.Targets) != 1 { + return errors.New("TCP/UDP services must have exactly one target") + } + if s.Mode == ModeUDP && s.Targets[0].ProxyProtocol { + return errors.New("proxy_protocol is not supported for UDP services") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateTLSMode() error { + if s.Domain == "" { + return errors.New("domain is required for TLS services (used for 
SNI matching)") + } + if s.isAuthEnabled() { + return errors.New("auth is not supported for TLS services") + } + if s.ListenPort == 0 { + return errors.New("listen_port is required for TLS services") + } + if len(s.Targets) != 1 { + return errors.New("TLS services must have exactly one target") + } + return s.validateL4Target(s.Targets[0]) +} + +func (s *Service) validateHTTPTargets() error { for i, target := range s.Targets { switch target.TargetType { case TargetTypePeer, TargetTypeHost, TargetTypeDomain: @@ -500,6 +626,9 @@ func (s *Service) Validate() error { if target.TargetId == "" { return fmt.Errorf("target %d has empty target_id", i) } + if target.ProxyProtocol { + return fmt.Errorf("target %d: proxy_protocol is not supported for HTTP services", i) + } if err := validateTargetOptions(i, &target.Options); err != nil { return err } @@ -508,11 +637,62 @@ func (s *Service) Validate() error { return nil } +func (s *Service) validateL4Target(target *Target) error { + if target.Port == 0 { + return errors.New("target port is required for L4 services") + } + if target.TargetId == "" { + return errors.New("target_id is required for L4 services") + } + switch target.TargetType { + case TargetTypePeer, TargetTypeHost: + // OK + case TargetTypeSubnet: + if target.Host == "" { + return errors.New("target host is required for subnet targets") + } + default: + return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType) + } + if target.Path != nil && *target.Path != "" && *target.Path != "/" { + return errors.New("path is not supported for L4 services") + } + return nil +} + +// Service mode constants. const ( - maxRequestTimeout = 5 * time.Minute - maxCustomHeaders = 16 - maxHeaderKeyLen = 128 - maxHeaderValueLen = 4096 + ModeHTTP = "http" + ModeTCP = "tcp" + ModeUDP = "udp" + ModeTLS = "tls" +) + +// Target protocol constants (URL scheme for backend connections). 
+const ( + TargetProtoHTTP = "http" + TargetProtoHTTPS = "https" + TargetProtoTCP = "tcp" + TargetProtoUDP = "udp" +) + +// IsL4Protocol returns true if the mode requires port-based routing (TCP, UDP, or TLS). +func IsL4Protocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP || mode == ModeTLS +} + +// IsPortBasedProtocol returns true if the mode relies on dedicated port allocation. +// TLS is excluded because it uses SNI routing and can share ports with other TLS services. +func IsPortBasedProtocol(mode string) bool { + return mode == ModeTCP || mode == ModeUDP +} + +const ( + maxRequestTimeout = 5 * time.Minute + maxSessionIdleTimeout = 10 * time.Minute + maxCustomHeaders = 16 + maxHeaderKeyLen = 128 + maxHeaderValueLen = 4096 ) // httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition. @@ -560,6 +740,15 @@ func validateTargetOptions(idx int, opts *TargetOptions) error { } } + if opts.SessionIdleTimeout != 0 { + if opts.SessionIdleTimeout <= 0 { + return fmt.Errorf("target %d: session_idle_timeout must be positive", idx) + } + if opts.SessionIdleTimeout > maxSessionIdleTimeout { + return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout) + } + } + if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil { return err } @@ -608,17 +797,49 @@ func containsCRLF(s string) bool { } func (s *Service) EventMeta() map[string]any { - return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()} + meta := map[string]any{ + "name": s.Name, + "domain": s.Domain, + "proxy_cluster": s.ProxyCluster, + "source": s.Source, + "auth": s.isAuthEnabled(), + "mode": s.Mode, + } + + if s.ListenPort != 0 { + meta["listen_port"] = s.ListenPort + } + + if len(s.Targets) > 0 { + t := s.Targets[0] + if t.ProxyProtocol { + meta["proxy_protocol"] = true + } + if t.Options.RequestTimeout != 0 { + 
meta["request_timeout"] = t.Options.RequestTimeout.String() + } + if t.Options.SessionIdleTimeout != 0 { + meta["session_idle_timeout"] = t.Options.SessionIdleTimeout.String() + } + } + + return meta } func (s *Service) isAuthEnabled() bool { - return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil + return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) || + (s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) || + (s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) } func (s *Service) Copy() *Service { targets := make([]*Target, len(s.Targets)) for i, target := range s.Targets { targetCopy := *target + if target.Path != nil { + p := *target.Path + targetCopy.Path = &p + } if len(target.Options.CustomHeaders) > 0 { targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders)) for k, v := range target.Options.CustomHeaders { @@ -628,6 +849,24 @@ func (s *Service) Copy() *Service { targets[i] = &targetCopy } + authCopy := s.Auth + if s.Auth.PasswordAuth != nil { + pa := *s.Auth.PasswordAuth + authCopy.PasswordAuth = &pa + } + if s.Auth.PinAuth != nil { + pa := *s.Auth.PinAuth + authCopy.PinAuth = &pa + } + if s.Auth.BearerAuth != nil { + ba := *s.Auth.BearerAuth + if len(s.Auth.BearerAuth.DistributionGroups) > 0 { + ba.DistributionGroups = make([]string, len(s.Auth.BearerAuth.DistributionGroups)) + copy(ba.DistributionGroups, s.Auth.BearerAuth.DistributionGroups) + } + authCopy.BearerAuth = &ba + } + return &Service{ ID: s.ID, AccountID: s.AccountID, @@ -638,12 +877,15 @@ func (s *Service) Copy() *Service { Enabled: s.Enabled, PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, - Auth: s.Auth, + Auth: authCopy, Meta: s.Meta, SessionPrivateKey: s.SessionPrivateKey, SessionPublicKey: s.SessionPublicKey, Source: s.Source, SourcePeer: s.SourcePeer, + Mode: s.Mode, + ListenPort: s.ListenPort, + PortAutoAssigned: s.PortAutoAssigned, } } @@ -688,12 +930,16 @@ var 
validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`) // ExposeServiceRequest contains the parameters for creating a peer-initiated expose service. type ExposeServiceRequest struct { NamePrefix string - Port int - Protocol string - Domain string - Pin string - Password string - UserGroups []string + Port uint16 + Mode string + // TargetProtocol is the protocol used to connect to the peer backend. + // For HTTP mode: "http" (default) or "https". For L4 modes: "tcp" or "udp". + TargetProtocol string + Domain string + Pin string + Password string + UserGroups []string + ListenPort uint16 } // Validate checks all fields of the expose request. @@ -702,12 +948,20 @@ func (r *ExposeServiceRequest) Validate() error { return errors.New("request cannot be nil") } - if r.Port < 1 || r.Port > 65535 { + if r.Port == 0 { return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port) } - if r.Protocol != "http" && r.Protocol != "https" { - return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol) + switch r.Mode { + case ModeHTTP, ModeTCP, ModeUDP, ModeTLS: + default: + return fmt.Errorf("unsupported mode %q", r.Mode) + } + + if IsL4Protocol(r.Mode) { + if r.Pin != "" || r.Password != "" || len(r.UserGroups) > 0 { + return fmt.Errorf("authentication is not supported for %s mode", r.Mode) + } } if r.Pin != "" && !pinRegexp.MatchString(r.Pin) { @@ -729,55 +983,79 @@ func (r *ExposeServiceRequest) Validate() error { // ToService builds a Service from the expose request. func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service { - service := &Service{ + svc := &Service{ AccountID: accountID, Name: serviceName, + Mode: r.Mode, Enabled: true, - Targets: []*Target{ - { - AccountID: accountID, - Port: r.Port, - Protocol: r.Protocol, - TargetId: peerID, - TargetType: TargetTypePeer, - Enabled: true, - }, + } + + // If domain is empty, CreateServiceFromPeer generates a unique subdomain. 
+ // When explicitly provided, the service name is prepended as a subdomain. + if r.Domain != "" { + svc.Domain = serviceName + "." + r.Domain + } + + if IsL4Protocol(r.Mode) { + svc.ListenPort = r.Port + if r.ListenPort > 0 { + svc.ListenPort = r.ListenPort + } + } + + var targetProto string + switch { + case !IsL4Protocol(r.Mode): + targetProto = TargetProtoHTTP + if r.TargetProtocol != "" { + targetProto = r.TargetProtocol + } + case r.Mode == ModeUDP: + targetProto = TargetProtoUDP + default: + targetProto = TargetProtoTCP + } + svc.Targets = []*Target{ + { + AccountID: accountID, + Port: r.Port, + Protocol: targetProto, + TargetId: peerID, + TargetType: TargetTypePeer, + Enabled: true, }, } - if r.Domain != "" { - service.Domain = serviceName + "." + r.Domain - } - if r.Pin != "" { - service.Auth.PinAuth = &PINAuthConfig{ + svc.Auth.PinAuth = &PINAuthConfig{ Enabled: true, Pin: r.Pin, } } if r.Password != "" { - service.Auth.PasswordAuth = &PasswordAuthConfig{ + svc.Auth.PasswordAuth = &PasswordAuthConfig{ Enabled: true, Password: r.Password, } } if len(r.UserGroups) > 0 { - service.Auth.BearerAuth = &BearerAuthConfig{ + svc.Auth.BearerAuth = &BearerAuthConfig{ Enabled: true, DistributionGroups: r.UserGroups, } } - return service + return svc } // ExposeServiceResponse contains the result of a successful peer expose creation. type ExposeServiceResponse struct { - ServiceName string - ServiceURL string - Domain string + ServiceName string + ServiceURL string + Domain string + PortAutoAssigned bool } // GenerateExposeName generates a random service name for peer-exposed services. 
diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index 79c98fc14..a8a8ae5d6 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -44,7 +44,7 @@ func TestValidate_EmptyDomain(t *testing.T) { func TestValidate_NoTargets(t *testing.T) { rp := validProxy() rp.Targets = nil - assert.ErrorContains(t, rp.Validate(), "at least one target") + assert.ErrorContains(t, rp.Validate(), "at least one target is required") } func TestValidate_EmptyTargetId(t *testing.T) { @@ -273,7 +273,7 @@ func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) { func TestIsDefaultPort(t *testing.T) { tests := []struct { scheme string - port int + port uint16 want bool }{ {"http", 80, true}, @@ -299,7 +299,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) { name string protocol string host string - port int + port uint16 wantTarget string }{ { @@ -645,8 +645,8 @@ func TestGenerateExposeName(t *testing.T) { func TestExposeServiceRequest_ToService(t *testing.T) { t.Run("basic HTTP service", func(t *testing.T) { req := &ExposeServiceRequest{ - Port: 8080, - Protocol: "http", + Port: 8080, + Mode: "http", } service := req.ToService("account-1", "peer-1", "mysvc") @@ -658,7 +658,7 @@ func TestExposeServiceRequest_ToService(t *testing.T) { require.Len(t, service.Targets, 1) target := service.Targets[0] - assert.Equal(t, 8080, target.Port) + assert.Equal(t, uint16(8080), target.Port) assert.Equal(t, "http", target.Protocol) assert.Equal(t, "peer-1", target.TargetId) assert.Equal(t, TargetTypePeer, target.TargetType) @@ -730,3 +730,182 @@ func TestExposeServiceRequest_ToService(t *testing.T) { require.NotNil(t, service.Auth.BearerAuth) }) } + +func TestValidate_TLSOnly(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: 
[]*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TLSMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 0, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "listen_port is required") +} + +func TestValidate_TLSMissingDomain(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "domain is required") +} + +func TestValidate_TCPValid(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_TCPMissingListenPort(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)") +} + +func TestValidate_L4MultipleTargets(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + {TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "exactly one target") +} + +func TestValidate_L4TargetMissingPort(t *testing.T) { + rp 
:= &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "port is required") +} + +func TestValidate_TLSInvalidTargetType(t *testing.T) { + rp := &Service{ + Name: "tls-svc", + Mode: "tls", + Domain: "example.com", + ListenPort: 443, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true}, + }, + } + assert.Error(t, rp.Validate()) +} + +func TestValidate_TLSSubnetValid(t *testing.T) { + rp := &Service{ + Name: "tls-subnet", + Mode: "tls", + Domain: "example.com", + ListenPort: 8443, + Targets: []*Target{ + {TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestValidate_HTTPProxyProtocolRejected(t *testing.T) { + rp := validProxy() + rp.Targets[0].ProxyProtocol = true + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP") +} + +func TestValidate_UDPProxyProtocolRejected(t *testing.T) { + rp := &Service{ + Name: "udp-svc", + Mode: "udp", + Domain: "cluster.test", + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP") +} + +func TestValidate_TCPProxyProtocolAllowed(t *testing.T) { + rp := &Service{ + Name: "tcp-svc", + Mode: "tcp", + Domain: "cluster.test", + ListenPort: 5432, + Targets: []*Target{ + {TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true}, + }, + } + require.NoError(t, rp.Validate()) +} + +func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) { + tests := []struct { + name string + req ExposeServiceRequest + 
}{ + { + name: "tcp with pin", + req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"}, + }, + { + name: "udp with password", + req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"}, + }, + { + name: "tls with user groups", + req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.req.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication is not supported") + }) + } +} + +func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) { + req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"} + require.NoError(t, req.Validate()) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index eb13a15e3..88d37ca80 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 29a8953ac..a32cf6046 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" @@ -211,6 +212,9 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { m := 
manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) + s.AfterInit(func(s *BaseServer) { + m.SetClusterCapabilities(s.ServiceProxyController()) + }) return &m }) } diff --git a/management/internals/shared/grpc/expose_service.go b/management/internals/shared/grpc/expose_service.go index c444471b0..1b87f7ede 100644 --- a/management/internals/shared/grpc/expose_service.go +++ b/management/internals/shared/grpc/expose_service.go @@ -2,6 +2,7 @@ package grpc import ( "context" + "fmt" pb "github.com/golang/protobuf/proto" // nolint log "github.com/sirupsen/logrus" @@ -39,23 +40,38 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } + if exposeReq.Port > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "port out of range: %d", exposeReq.Port) + } + if exposeReq.ListenPort > 65535 { + return nil, status.Errorf(codes.InvalidArgument, "listen_port out of range: %d", exposeReq.ListenPort) + } + + mode, err := exposeProtocolToString(exposeReq.Protocol) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{ - NamePrefix: exposeReq.NamePrefix, - Port: int(exposeReq.Port), - Protocol: exposeProtocolToString(exposeReq.Protocol), - Domain: exposeReq.Domain, - Pin: exposeReq.Pin, - Password: exposeReq.Password, - UserGroups: exposeReq.UserGroups, + NamePrefix: exposeReq.NamePrefix, + Port: uint16(exposeReq.Port), //nolint:gosec // validated above + Mode: mode, + TargetProtocol: exposeTargetProtocol(exposeReq.Protocol), + Domain: exposeReq.Domain, + Pin: exposeReq.Pin, + Password: exposeReq.Password, + UserGroups: exposeReq.UserGroups, + ListenPort: uint16(exposeReq.ListenPort), //nolint:gosec // validated above }) if err != nil { return nil, mapExposeError(ctx, err) } return 
s.encryptResponse(peerKey, &proto.ExposeServiceResponse{ - ServiceName: created.ServiceName, - ServiceUrl: created.ServiceURL, - Domain: created.Domain, + ServiceName: created.ServiceName, + ServiceUrl: created.ServiceURL, + Domain: created.Domain, + PortAutoAssigned: created.PortAutoAssigned, }) } @@ -77,7 +93,12 @@ func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) ( return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil { + serviceID, err := s.resolveServiceID(ctx, renewReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { return nil, mapExposeError(ctx, err) } @@ -102,7 +123,12 @@ func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (* return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil { + serviceID, err := s.resolveServiceID(ctx, stopReq.Domain) + if err != nil { + return nil, mapExposeError(ctx, err) + } + + if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil { return nil, mapExposeError(ctx, err) } @@ -180,13 +206,46 @@ func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) { s.reverseProxyManager = mgr } -func exposeProtocolToString(p proto.ExposeProtocol) string { +// resolveServiceID looks up the service by its globally unique domain. 
+func (s *Server) resolveServiceID(ctx context.Context, domain string) (string, error) { + if domain == "" { + return "", status.Errorf(codes.InvalidArgument, "domain is required") + } + + svc, err := s.accountManager.GetStore().GetServiceByDomain(ctx, domain) + if err != nil { + return "", err + } + return svc.ID, nil +} + +func exposeProtocolToString(p proto.ExposeProtocol) (string, error) { switch p { - case proto.ExposeProtocol_EXPOSE_HTTP: - return "http" - case proto.ExposeProtocol_EXPOSE_HTTPS: - return "https" + case proto.ExposeProtocol_EXPOSE_HTTP, proto.ExposeProtocol_EXPOSE_HTTPS: + return "http", nil + case proto.ExposeProtocol_EXPOSE_TCP: + return "tcp", nil + case proto.ExposeProtocol_EXPOSE_UDP: + return "udp", nil + case proto.ExposeProtocol_EXPOSE_TLS: + return "tls", nil default: - return "http" + return "", fmt.Errorf("unsupported expose protocol: %v", p) + } +} + +// exposeTargetProtocol returns the target protocol for the given expose protocol. +// For HTTP mode, this is http or https (the scheme used to connect to the backend). +// For L4 modes, this is tcp or udp (the transport used to connect to the backend). 
+func exposeTargetProtocol(p proto.ExposeProtocol) string { + switch p { + case proto.ExposeProtocol_EXPOSE_HTTPS: + return rpservice.TargetProtoHTTPS + case proto.ExposeProtocol_EXPOSE_TCP, proto.ExposeProtocol_EXPOSE_TLS: + return rpservice.TargetProtoTCP + case proto.ExposeProtocol_EXPOSE_UDP: + return rpservice.TargetProtoUDP + default: + return rpservice.TargetProtoHTTP } } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index e2d0f1abe..31a0ba0db 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -32,6 +32,7 @@ import ( proxyauth "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/management/proto" + nbstatus "github.com/netbirdio/netbird/shared/management/status" ) type ProxyOIDCConfig struct { @@ -45,12 +46,6 @@ type ProxyOIDCConfig struct { KeysLocation string } -// ClusterInfo contains information about a proxy cluster. 
-type ClusterInfo struct { - Address string - ConnectedProxies int -} - // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -61,9 +56,9 @@ type ProxyServiceServer struct { // Manager for access logs accessLogManager accesslogs.Manager + mu sync.RWMutex // Manager for reverse proxy operations serviceManager rpservice.Manager - // ProxyController for service updates and cluster management proxyController proxy.Controller @@ -84,23 +79,26 @@ type ProxyServiceServer struct { // Store for PKCE verifiers pkceVerifierStore *PKCEVerifierStore + + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute // proxyConnection represents a connected proxy type proxyConnection struct { - proxyID string - address string - stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.GetMappingUpdateResponse - ctx context.Context - cancel context.CancelFunc + proxyID string + address string + capabilities *proto.ProxyCapabilities + stream proto.ProxyService_GetMappingUpdateServer + sendChan chan *proto.GetMappingUpdateResponse + ctx context.Context + cancel context.CancelFunc } // NewProxyServiceServer creates a new proxy service server. 
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ accessLogManager: accessLogMgr, oidcConfig: oidcConfig, @@ -109,6 +107,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + cancel: cancel, } go s.cleanupStaleProxies(ctx) return s @@ -130,11 +129,22 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { } } +// Close stops background goroutines. +func (s *ProxyServiceServer) Close() { + s.cancel() +} + +// SetServiceManager sets the service manager. Must be called before serving. func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { + s.mu.Lock() + defer s.mu.Unlock() s.serviceManager = manager } +// SetProxyController sets the proxy controller. Must be called before serving. 
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) { + s.mu.Lock() + defer s.mu.Unlock() s.proxyController = proxyController } @@ -157,12 +167,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ - proxyID: proxyID, - address: proxyAddress, - stream: stream, - sendChan: make(chan *proto.GetMappingUpdateResponse, 100), - ctx: connCtx, - cancel: cancel, + proxyID: proxyID, + address: proxyAddress, + capabilities: req.GetCapabilities(), + stream: stream, + sendChan: make(chan *proto.GetMappingUpdateResponse, 100), + ctx: connCtx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) @@ -231,29 +242,18 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) { } // sendSnapshot sends the initial snapshot of services to the connecting proxy. -// Only services matching the proxy's cluster address are sent. +// Only entries matching the proxy's cluster address are sent. 
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - services, err := s.serviceManager.GetGlobalServices(ctx) - if err != nil { - return fmt.Errorf("get services from store: %w", err) - } - if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") } - var filtered []*rpservice.Service - for _, service := range services { - if !service.Enabled { - continue - } - if service.ProxyCluster == "" || service.ProxyCluster != conn.address { - continue - } - filtered = append(filtered, service) + mappings, err := s.snapshotServiceMappings(ctx, conn) + if err != nil { + return err } - if len(filtered) == 0 { + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { @@ -262,9 +262,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return nil } - for i, service := range filtered { - // Generate one-time authentication token for each service in the snapshot - // Tokens are not persistent on the proxy, so we need to generate new ones on reconnection + for i, m := range mappings { + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{m}, + InitialSyncComplete: i == len(mappings)-1, + }); err != nil { + return fmt.Errorf("send proxy mapping: %w", err) + } + } + + return nil +} + +func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { + services, err := s.serviceManager.GetGlobalServices(ctx) + if err != nil { + return nil, fmt.Errorf("get services from store: %w", err) + } + + var mappings []*proto.ProxyMapping + for _, service := range services { + if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { + continue + } + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) if err != nil { log.WithFields(log.Fields{ @@ 
-274,25 +295,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec continue } - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{ - service.ToProtoMapping( - rpservice.Create, // Initial snapshot, all records are "new" for the proxy. - token, - s.GetOIDCValidationConfig(), - ), - }, - InitialSyncComplete: i == len(filtered)-1, - }); err != nil { - log.WithFields(log.Fields{ - "domain": service.Domain, - "account": service.AccountID, - }).WithError(err).Error("failed to send proxy mapping") - return fmt.Errorf("send proxy mapping: %w", err) - } + m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + mappings = append(mappings, m) } - - return nil + return mappings, nil } // isProxyAddressValid validates a proxy address @@ -305,8 +311,8 @@ func isProxyAddressValid(addr string) bool { func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { select { - case msg := <-conn.sendChan: - if err := conn.stream.Send(msg); err != nil { + case resp := <-conn.sendChan: + if err := conn.stream.Send(resp); err != nil { errChan <- err return } @@ -361,12 +367,12 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes log.Debugf("Broadcasting service update to all connected proxy servers") s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - msg := s.perProxyMessage(update, conn.proxyID) - if msg == nil { + resp := s.perProxyMessage(update, conn.proxyID) + if resp == nil { return true } select { - case conn.sendChan <- msg: + case conn.sendChan <- resp: log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) @@ -495,9 +501,40 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { Auth: m.Auth, PassHostHeader: m.PassHostHeader, RewriteRedirects: 
m.RewriteRedirects, + Mode: m.Mode, + ListenPort: m.ListenPort, } } +// ClusterSupportsCustomPorts returns whether any connected proxy in the given +// cluster reports custom port support. Returns nil if no proxy has reported +// capabilities (old proxies that predate the field). +func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool { + if s.proxyController == nil { + return nil + } + + var hasCapabilities bool + for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { + connVal, ok := s.connectedProxies.Load(pid) + if !ok { + continue + } + conn := connVal.(*proxyConnection) + if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil { + continue + } + if *conn.capabilities.SupportsCustomPorts { + return ptr(true) + } + hasCapabilities = true + } + if hasCapabilities { + return ptr(false) + } + return nil +} + func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { @@ -585,7 +622,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic return token, nil } -// SendStatusUpdate handles status updates from proxy clients +// SendStatusUpdate handles status updates from proxy clients. 
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { accountID := req.GetAccountId() serviceID := req.GetServiceId() @@ -604,6 +641,17 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required") } + internalStatus := protoStatusToInternal(protoStatus) + + if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { + sErr, isNbErr := nbstatus.FromError(err) + if isNbErr && sErr.Type() == nbstatus.NotFound { + return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID) + } + log.WithContext(ctx).WithError(err).Error("failed to update service status") + return nil, status.Errorf(codes.Internal, "update service status: %v", err) + } + if certificateIssued { if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") @@ -615,13 +663,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se }).Info("Certificate issued timestamp updated") } - internalStatus := protoStatusToInternal(protoStatus) - - if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { - log.WithContext(ctx).WithError(err).Error("failed to update service status") - return nil, status.Errorf(codes.Internal, "update service status: %v", err) - } - log.WithFields(log.Fields{ "service_id": serviceID, "account_id": accountID, @@ -631,7 +672,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se return &proto.SendStatusUpdateResponse{}, nil } -// protoStatusToInternal maps proto status to internal status +// protoStatusToInternal maps proto status to internal service status. 
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { switch protoStatus { case proto.ProxyStatus_PROXY_STATUS_PENDING: @@ -1061,3 +1102,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user * return fmt.Errorf("user not in allowed groups") } + +func ptr[T any](v T) *T { return &v } diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index b7abb28b6..1a4ea3330 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -53,6 +53,10 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus return nil } +func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { + return ptr(true) +} + func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { c.mu.Lock() defer c.mu.Unlock() @@ -70,11 +74,17 @@ func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string // registerFakeProxy adds a fake proxy connection to the server's internal maps // and returns the channel where messages will be received. func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { + return registerFakeProxyWithCaps(s, proxyID, clusterAddr, nil) +} + +// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. 
+func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { ch := make(chan *proto.GetMappingUpdateResponse, 10) conn := &proxyConnection{ - proxyID: proxyID, - address: clusterAddr, - sendChan: ch, + proxyID: proxyID, + address: clusterAddr, + capabilities: caps, + sendChan: ch, } s.connectedProxies.Store(proxyID, conn) @@ -83,15 +93,29 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan return ch } -func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse { +// drainMapping drains a single ProxyMapping from the channel. +func drainMapping(ch chan *proto.GetMappingUpdateResponse) *proto.ProxyMapping { select { - case msg := <-ch: - return msg + case resp := <-ch: + if len(resp.Mapping) > 0 { + return resp.Mapping[0] + } + return nil case <-time.After(time.Second): return nil } } +// drainEmpty checks if a channel has no message within timeout. 
+func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool { + select { + case <-ch: + return false + case <-time.After(100 * time.Millisecond): + return true + } +} + func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) @@ -129,10 +153,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { tokens := make([]string, numProxies) for i, ch := range channels { - resp := drainChannel(ch) - require.NotNil(t, resp, "proxy %d should receive a message", i) - require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i) - msg := resp.Mapping[0] + msg := drainMapping(ch) + require.NotNil(t, msg, "proxy %d should receive a message", i) assert.Equal(t, mapping.Domain, msg.Domain) assert.Equal(t, mapping.Id, msg.Id) assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) @@ -181,16 +203,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { s.SendServiceUpdateToCluster(context.Background(), mapping, cluster) - resp1 := drainChannel(ch1) - resp2 := drainChannel(ch2) - require.NotNil(t, resp1) - require.NotNil(t, resp2) - require.Len(t, resp1.Mapping, 1) - require.Len(t, resp2.Mapping, 1) + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) // Delete operations should not generate tokens - assert.Empty(t, resp1.Mapping[0].AuthToken) - assert.Empty(t, resp2.Mapping[0].AuthToken) + assert.Empty(t, msg1.AuthToken) + assert.Empty(t, msg2.AuthToken) } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { @@ -224,15 +244,10 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { s.SendServiceUpdate(update) - resp1 := drainChannel(ch1) - resp2 := drainChannel(ch2) - require.NotNil(t, resp1) - require.NotNil(t, resp2) - require.Len(t, resp1.Mapping, 1) - require.Len(t, resp2.Mapping, 1) - - msg1 := 
resp1.Mapping[0] - msg2 := resp2.Mapping[0] + msg1 := drainMapping(ch1) + msg2 := drainMapping(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) assert.NotEmpty(t, msg1.AuthToken) assert.NotEmpty(t, msg2.AuthToken) @@ -324,3 +339,314 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "invalid state signature") } + +func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + // Proxy A supports custom ports. + chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + // Proxy B does NOT support custom ports (shared cloud proxy). + chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + ctx := context.Background() + + // TLS passthrough works on all proxies regardless of custom port support. + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster) + + msgA := drainMapping(chA) + msgB := drainMapping(chB) + assert.NotNil(t, msgA, "proxy-a should receive TLS mapping") + assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)") + + // Send an HTTP mapping: both should receive it. 
+ httpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-http", + AccountId: "account-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:80"}}, + } + + s.SendServiceUpdateToCluster(ctx, httpMapping, cluster) + + msgA = drainMapping(chA) + msgB = drainMapping(chB) + assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping") + assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping") +} + +func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + + const cluster = "proxy.example.com" + + chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + tlsMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tls", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tls", + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster) + + msg := drainMapping(chShared) + assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support") +} + +// TestServiceModifyNotifications exercises every possible modification +// scenario for an existing service, verifying the correct update types +// reach the correct clusters. 
+func TestServiceModifyNotifications(t *testing.T) { + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + chs := map[string]chan *proto.GetMappingUpdateResponse{ + "cluster-a": registerFakeProxyWithCaps(s, "proxy-a", "cluster-a", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + "cluster-b": registerFakeProxyWithCaps(s, "proxy-b", "cluster-b", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}), + } + return s, chs + } + + httpMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:8080"}}, + } + } + + tlsOnlyMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: updateType, + Id: "svc-1", + AccountId: "acct-1", + Domain: "app.example.com", + Mode: "tls", + ListenPort: 8443, + Path: []*proto.PathMapping{{Target: "10.0.0.1:443"}}, + } + } + + ctx := context.Background() + + t.Run("targets changed sends MODIFIED to same cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg, "cluster-a should receive update") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.NotEmpty(t, msg.AuthToken, "MODIFIED should include token") + assert.True(t, drainEmpty(chs["cluster-b"]), "cluster-b should not receive update") + }) + + t.Run("auth config changed sends MODIFIED", func(t *testing.T) { + s, chs := newServer() + 
mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.Auth = &proto.Authentication{Password: true, Pin: true} + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.True(t, msg.Auth.Password) + assert.True(t, msg.Auth.Pin) + }) + + t.Run("HTTP to TLS transition sends MODIFIED with TLS config", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Equal(t, "tls", msg.Mode, "mode should be tls") + assert.Equal(t, int32(8443), msg.ListenPort) + assert.Len(t, msg.Path, 1, "should have one path entry with target address") + assert.Equal(t, "10.0.0.1:443", msg.Path[0].Target) + }) + + t.Run("TLS to HTTP transition sends MODIFIED without TLS", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type) + assert.Empty(t, msg.Mode, "mode should be empty for HTTP") + assert.True(t, len(msg.Path) > 0) + }) + + t.Run("TLS port changed sends MODIFIED with new port", func(t *testing.T) { + s, chs := newServer() + mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 9443 + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, int32(9443), msg.ListenPort) + }) + + t.Run("disable sends REMOVED to cluster", func(t *testing.T) { + s, chs := newServer() + // Manager sends 
Delete when service is disabled + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.Type) + assert.Empty(t, msg.AuthToken, "DELETE should not have token") + }) + + t.Run("enable sends CREATED to cluster", func(t *testing.T) { + s, chs := newServer() + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msg.Type) + assert.NotEmpty(t, msg.AuthToken) + }) + + t.Run("domain change with cluster change sends DELETE to old CREATE to new", func(t *testing.T) { + s, chs := newServer() + // This is the pattern the manager produces: + // 1. DELETE on old cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + // 2. CREATE on new cluster + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-b") + + msgA := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgA, "old cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgA.Type) + + msgB := drainMapping(chs["cluster-b"]) + require.NotNil(t, msgB, "new cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgB.Type) + assert.NotEmpty(t, msgB.AuthToken) + }) + + t.Run("domain change same cluster sends DELETE then CREATE", func(t *testing.T) { + s, chs := newServer() + // Domain changes within same cluster: manager sends DELETE (old domain) + CREATE (new domain). 
+ s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a") + s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a") + + msgDel := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgDel, "same cluster should receive DELETE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgDel.Type) + + msgCreate := drainMapping(chs["cluster-a"]) + require.NotNil(t, msgCreate, "same cluster should receive CREATE") + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgCreate.Type) + assert.NotEmpty(t, msgCreate.AuthToken) + }) + + t.Run("TLS passthrough sent to all proxies", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + // TLS passthrough works on all proxies regardless of custom port support + s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster) + + msgModern := drainMapping(chModern) + require.NotNil(t, msgModern, "modern proxy receives TLS update") + assert.Equal(t, "tls", msgModern.Mode) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should also receive TLS passthrough") + }) + + t.Run("TLS on default port NOT filtered for legacy proxy", func(t *testing.T) { + s := &ProxyServiceServer{ + tokenStore: tokenStore, + } + s.SetProxyController(newTestProxyController()) + const cluster = "proxy.example.com" + chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + mapping := 
tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.ListenPort = 0 // default port + s.SendServiceUpdateToCluster(ctx, mapping, cluster) + + msgLegacy := drainMapping(chLegacy) + assert.NotNil(t, msgLegacy, "legacy proxy should receive TLS on default port") + }) + + t.Run("passthrough and rewrite flags propagated", func(t *testing.T) { + s, chs := newServer() + mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + mapping.PassHostHeader = true + mapping.RewriteRedirects = true + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + assert.True(t, msg.PassHostHeader) + assert.True(t, msg.RewriteRedirects) + }) + + t.Run("multiple paths propagated in MODIFIED", func(t *testing.T) { + s, chs := newServer() + mapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, + Id: "svc-multi", + AccountId: "acct-1", + Domain: "multi.example.com", + Path: []*proto.PathMapping{ + {Path: "/", Target: "http://10.0.0.1:8080"}, + {Path: "/api", Target: "http://10.0.0.2:9090"}, + {Path: "/ws", Target: "http://10.0.0.3:3000"}, + }, + } + s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a") + + msg := drainMapping(chs["cluster-a"]) + require.NotNil(t, msg) + require.Len(t, msg.Path, 3, "all paths should be present") + assert.Equal(t, "/", msg.Path[0].Path) + assert.Equal(t, "/api", msg.Path[1].Path) + assert.Equal(t, "/ws", msg.Path[2].Path) + }) +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ddeda6d7f..ad36b9d46 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -174,9 +174,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) if serviceManager != nil && reverseProxyDomainManager != nil { - 
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } - // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 462013963..6bd269a2c 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel/metric/noop" "github.com/netbirdio/management-integrations/integrations" + accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -113,6 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } + domainManager.SetClusterCapabilities(serviceProxyController) serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index bfefce388..8732cf89f 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -219,7 +219,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { servicesStatusActive int servicesStatusPending int 
servicesStatusError int - servicesTargetType map[string]int + servicesTargetType map[rpservice.TargetType]int servicesAuthPassword int servicesAuthPin int servicesAuthOIDC int @@ -232,7 +232,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { rulesDirection = make(map[string]int) activeUsersLastDay = make(map[string]struct{}) embeddedIdpTypes = make(map[string]int) - servicesTargetType = make(map[string]int) + servicesTargetType = make(map[rpservice.TargetType]int) uptime = time.Since(w.startupTime).Seconds() connections := w.connManager.GetAllConnectedPeers() version = nbversion.NetbirdVersion() @@ -434,7 +434,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["custom_domains_validated"] = customDomainsValidated for targetType, count := range servicesTargetType { - metricsProperties["services_target_type_"+targetType] = count + metricsProperties["services_target_type_"+string(targetType)] = count } for idpType, count := range embeddedIdpTypes { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 5997c10e2..b3fbfe141 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -30,6 +30,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" @@ -4996,6 +4997,7 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpse return service, nil } + func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != 
LockingStrengthNone { @@ -5041,16 +5043,16 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS } // RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service. -func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { +func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { result := s.db.Model(&rpservice.Service{}). - Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Where("id = ? AND account_id = ? AND source_peer = ? AND source = ?", serviceID, accountID, peerID, rpservice.SourceEphemeral). Update("meta_last_renewed_at", time.Now()) if result.Error != nil { log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error) return status.Errorf(status.Internal, "renew ephemeral service") } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "no active expose session for domain %s", domain) + return status.Errorf(status.NotFound, "no active expose session for service %s", serviceID) } return nil } @@ -5133,6 +5135,37 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock return id != "", nil } +// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port. +func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ? AND mode = ? 
AND listen_port = ?", proxyCluster, mode, listenPort).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster and port") + } + + return services, nil +} + +// GetServicesByCluster returns all services for the given proxy cluster. +func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var services []*rpservice.Service + result := tx.Where("proxy_cluster = ?", proxyCluster).Find(&services) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "query services by cluster") + } + return services, nil +} + func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 1fa99fd05..8bb52f38a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -261,10 +261,12 @@ type Store interface { GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) - RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error + RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) + GetServicesByClusterAndPort(ctx context.Context, lockStrength 
LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) + GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 130df4485..e75e35b94 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1991,6 +1991,36 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength) } +// GetServicesByCluster mocks base method. +func (m *MockStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByCluster", ctx, lockStrength, proxyCluster) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByCluster indicates an expected call of GetServicesByCluster. +func (mr *MockStoreMockRecorder) GetServicesByCluster(ctx, lockStrength, proxyCluster interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByCluster", reflect.TypeOf((*MockStore)(nil).GetServicesByCluster), ctx, lockStrength, proxyCluster) +} + +// GetServicesByClusterAndPort mocks base method. 
+func (m *MockStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster, mode string, listenPort uint16) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByClusterAndPort", ctx, lockStrength, proxyCluster, mode, listenPort) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByClusterAndPort indicates an expected call of GetServicesByClusterAndPort. +func (mr *MockStoreMockRecorder) GetServicesByClusterAndPort(ctx, lockStrength, proxyCluster, mode, listenPort interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByClusterAndPort", reflect.TypeOf((*MockStore)(nil).GetServicesByClusterAndPort), ctx, lockStrength, proxyCluster, mode, listenPort) +} + // GetSetupKeyByID mocks base method. func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) { m.ctrl.T.Helper() @@ -2447,17 +2477,17 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID } // RenewEphemeralService mocks base method. -func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { +func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain) + ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, serviceID) ret0, _ := ret[0].(error) return ret0 } // RenewEphemeralService indicates an expected call of RenewEphemeralService. 
-func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, serviceID) } // RevokeProxyAccessToken mocks base method. diff --git a/management/server/types/account.go b/management/server/types/account.go index 6145ceeb2..269fc7a88 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -907,8 +907,8 @@ func (a *Account) Copy() *Account { } services := []*service.Service{} - for _, service := range a.Services { - services = append(services, service.Copy()) + for _, svc := range a.Services { + services = append(services, svc.Copy()) } return &Account{ @@ -1605,12 +1605,12 @@ func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { networkResourceGroups := a.getNetworkResourceGroups(resourceId) for _, policy := range a.Policies { - if !policy.Enabled { + if policy == nil || !policy.Enabled { continue } for _, rule := range policy.Rules { - if !rule.Enabled { + if rule == nil || !rule.Enabled { continue } @@ -1812,15 +1812,18 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) { } a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster) } + } func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { + proxyPeers := proxyPeersByCluster[service.ProxyCluster] for _, target := range service.Targets { if !target.Enabled { continue } - a.injectTargetProxyPolicies(ctx, service, target, 
proxyPeersByCluster[service.ProxyCluster]) + a.injectTargetProxyPolicies(ctx, service, target, proxyPeers) } + } func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) { @@ -1840,13 +1843,13 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *servic } } -func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) { +func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (uint16, bool) { if target.Port != 0 { return target.Port, true } switch target.Protocol { - case "https": + case "https", "tls": return 443, true case "http": return 80, true @@ -1856,17 +1859,23 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) } } -func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { - policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path) +func (a *Account) createProxyPolicy(svc *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port uint16, path string) *Policy { + policyID := fmt.Sprintf("proxy-access-%s-%s-%s", svc.ID, proxyPeer.ID, path) + + protocol := PolicyRuleProtocolTCP + if svc.Mode == service.ModeUDP { + protocol = PolicyRuleProtocolUDP + } + return &Policy{ ID: policyID, - Name: fmt.Sprintf("Proxy Access to %s", service.Name), + Name: fmt.Sprintf("Proxy Access to %s", svc.Name), Enabled: true, Rules: []*PolicyRule{ { ID: policyID, PolicyID: policyID, - Name: fmt.Sprintf("Allow access to %s", service.Name), + Name: fmt.Sprintf("Allow access to %s", svc.Name), Enabled: true, SourceResource: Resource{ ID: proxyPeer.ID, @@ -1877,12 +1886,12 @@ func (a *Account) createProxyPolicy(service *service.Service, target *service.Ta Type: ResourceType(target.TargetType), }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: protocol, 
Action: PolicyTrafficActionAccept, PortRanges: []RulePortRange{ { - Start: uint16(port), - End: uint16(port), + Start: port, + End: port, }, }, }, diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 60e81feb5..d82f5b7fc 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -7,6 +7,7 @@ import ( "os/signal" "strconv" "syscall" + "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -34,30 +35,32 @@ var ( ) var ( - logLevel string - debugLogs bool - mgmtAddr string - addr string - proxyDomain string - certDir string - acmeCerts bool - acmeAddr string - acmeDir string - acmeEABKID string - acmeEABHMACKey string - acmeChallengeType string - debugEndpoint bool - debugEndpointAddr string - healthAddr string - forwardedProto string - trustedProxies string - certFile string - certKeyFile string - certLockMethod string - wildcardCertDir string - wgPort int - proxyProtocol bool - preSharedKey string + logLevel string + debugLogs bool + mgmtAddr string + addr string + proxyDomain string + defaultDialTimeout time.Duration + certDir string + acmeCerts bool + acmeAddr string + acmeDir string + acmeEABKID string + acmeEABHMACKey string + acmeChallengeType string + debugEndpoint bool + debugEndpointAddr string + healthAddr string + forwardedProto string + trustedProxies string + certFile string + certKeyFile string + certLockMethod string + wildcardCertDir string + wgPort uint16 + proxyProtocol bool + preSharedKey string + supportsCustomPorts bool ) var rootCmd = &cobra.Command{ @@ -92,9 +95,11 @@ func init() { rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory") rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") 
rootCmd.Flags().StringVar(&wildcardCertDir, "wildcard-cert-dir", envStringOrDefault("NB_PROXY_WILDCARD_CERT_DIR", ""), "Directory containing wildcard certificate pairs (.crt/.key). Wildcard patterns are extracted from SANs automatically") - rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") + rootCmd.Flags().Uint16Var(&wgPort, "wg-port", envUint16OrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") + rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") + rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)") } // Execute runs the root command. 
@@ -171,6 +176,8 @@ func runServer(cmd *cobra.Command, args []string) error { WireguardPort: wgPort, ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, + SupportsCustomPorts: supportsCustomPorts, + DefaultDialTimeout: defaultDialTimeout, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) @@ -203,12 +210,24 @@ func envStringOrDefault(key string, def string) string { return v } -func envIntOrDefault(key string, def int) int { +func envUint16OrDefault(key string, def uint16) uint16 { v, exists := os.LookupEnv(key) if !exists { return def } - parsed, err := strconv.Atoi(v) + parsed, err := strconv.ParseUint(v, 10, 16) + if err != nil { + return def + } + return uint16(parsed) +} + +func envDurationOrDefault(key string, def time.Duration) time.Duration { + v, exists := os.LookupEnv(key) + if !exists { + return def + } + parsed, err := time.ParseDuration(v) if err != nil { return def } diff --git a/proxy/handle_mapping_stream_test.go b/proxy/handle_mapping_stream_test.go index d2ad3f67e..cb16c0814 100644 --- a/proxy/handle_mapping_stream_test.go +++ b/proxy/handle_mapping_stream_test.go @@ -38,11 +38,18 @@ func (m *mockMappingStream) Context() context.Context { return context.Backgroun func (m *mockMappingStream) SendMsg(any) error { return nil } func (m *mockMappingStream) RecvMsg(any) error { return nil } +func closedChan() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) { checker := health.NewChecker(nil, nil) s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -62,6 +69,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { s := &Server{ Logger: log.StandardLogger(), healthChecker: checker, + routerReady: closedChan(), } stream := &mockMappingStream{ @@ -78,7 +86,8 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { 
func TestHandleMappingStream_NilHealthChecker(t *testing.T) { s := &Server{ - Logger: log.StandardLogger(), + Logger: log.StandardLogger(), + routerReady: closedChan(), } stream := &mockMappingStream{ diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 4ba5a7755..5b05ab195 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -6,11 +6,13 @@ import ( "sync" "time" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -19,6 +21,7 @@ const ( bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours + logSendTimeout = 10 * time.Second ) type domainUsage struct { @@ -79,22 +82,63 @@ func (l *Logger) Close() { type logEntry struct { ID string - AccountID string - ServiceId string + AccountID types.AccountID + ServiceId types.ServiceID Host string Path string DurationMs int64 Method string ResponseCode int32 - SourceIp string + SourceIP netip.Addr AuthMechanism string UserId string AuthSuccess bool BytesUpload int64 BytesDownload int64 + Protocol Protocol } -func (l *Logger) log(ctx context.Context, entry logEntry) { +// Protocol identifies the transport protocol of an access log entry. +type Protocol string + +const ( + ProtocolHTTP Protocol = "http" + ProtocolTCP Protocol = "tcp" + ProtocolUDP Protocol = "udp" + ProtocolTLS Protocol = "tls" +) + +// L4Entry holds the data for a layer-4 (TCP/UDP) access log entry. 
+type L4Entry struct { + AccountID types.AccountID + ServiceID types.ServiceID + Protocol Protocol + Host string // SNI hostname or listen address + SourceIP netip.Addr + DurationMs int64 + BytesUpload int64 + BytesDownload int64 +} + +// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP). +// The call is non-blocking: the gRPC send happens in a background goroutine. +func (l *Logger) LogL4(entry L4Entry) { + le := logEntry{ + ID: xid.New().String(), + AccountID: entry.AccountID, + ServiceId: entry.ServiceID, + Protocol: entry.Protocol, + Host: entry.Host, + SourceIP: entry.SourceIP, + DurationMs: entry.DurationMs, + BytesUpload: entry.BytesUpload, + BytesDownload: entry.BytesDownload, + } + l.log(le) + l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload) +} + +func (l *Logger) log(entry logEntry) { // Fire off the log request in a separate routine. // This increases the possibility of losing a log message // (although it should still get logged in the event of an error), @@ -105,31 +149,37 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { // allow for resolving that on the server. now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary. 
go func() { - logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout) defer cancel() if entry.AuthMechanism != auth.MethodOIDC.String() { entry.UserId = "" } + + var sourceIP string + if entry.SourceIP.IsValid() { + sourceIP = entry.SourceIP.String() + } + if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{ Log: &proto.AccessLog{ LogId: entry.ID, - AccountId: entry.AccountID, + AccountId: string(entry.AccountID), Timestamp: now, - ServiceId: entry.ServiceId, + ServiceId: string(entry.ServiceId), Host: entry.Host, Path: entry.Path, DurationMs: entry.DurationMs, Method: entry.Method, ResponseCode: entry.ResponseCode, - SourceIp: entry.SourceIp, + SourceIp: sourceIP, AuthMechanism: entry.AuthMechanism, UserId: entry.UserId, AuthSuccess: entry.AuthSuccess, BytesUpload: entry.BytesUpload, BytesDownload: entry.BytesDownload, + Protocol: string(entry.Protocol), }, }); err != nil { - // If it fails to send on the gRPC connection, then at least log it to the error log. 
l.logger.WithFields(log.Fields{ "service_id": entry.ServiceId, "host": entry.Host, @@ -137,7 +187,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) { "duration": entry.DurationMs, "method": entry.Method, "response_code": entry.ResponseCode, - "source_ip": entry.SourceIp, + "source_ip": sourceIP, "auth_mechanism": entry.AuthMechanism, "user_id": entry.UserId, "auth_success": entry.AuthSuccess, diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 7368185c0..593a77ef2 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -67,23 +67,24 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { entry := logEntry{ ID: requestID, ServiceId: capturedData.GetServiceId(), - AccountID: string(capturedData.GetAccountId()), + AccountID: capturedData.GetAccountId(), Host: host, Path: r.URL.Path, DurationMs: duration.Milliseconds(), Method: r.Method, ResponseCode: int32(sw.status), - SourceIp: sourceIp, + SourceIP: sourceIp, AuthMechanism: capturedData.GetAuthMethod(), UserId: capturedData.GetUserID(), AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, BytesUpload: bytesUpload, BytesDownload: bytesDownload, + Protocol: ProtocolHTTP, } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) - l.log(r.Context(), entry) + l.log(entry) // Track usage for cost monitoring (upload + download) by domain l.trackUsage(host, bytesUpload+bytesDownload) diff --git a/proxy/internal/accesslog/requestip.go b/proxy/internal/accesslog/requestip.go index f111c1322..30c483fd9 100644 --- a/proxy/internal/accesslog/requestip.go +++ b/proxy/internal/accesslog/requestip.go @@ -11,6 +11,6 @@ import ( // proxy 
configuration. When trustedProxies is non-empty and the direct // connection is from a trusted source, it walks X-Forwarded-For right-to-left // skipping trusted IPs. Otherwise it returns RemoteAddr directly. -func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string { +func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) netip.Addr { return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies) } diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index 395da7d88..a4a220ed7 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -23,6 +23,7 @@ import ( "golang.org/x/crypto/acme/autocert" "github.com/netbirdio/netbird/proxy/internal/certwatch" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -30,7 +31,7 @@ import ( var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2} type certificateNotifier interface { - NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error + NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error } type domainState int @@ -42,8 +43,8 @@ const ( ) type domainInfo struct { - accountID string - serviceID string + accountID types.AccountID + serviceID types.ServiceID state domainState err string } @@ -301,7 +302,7 @@ func (mgr *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate // When AddDomain returns true the caller is responsible for sending any // certificate-ready notifications after the surrounding operation (e.g. // mapping update) has committed successfully. 
-func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) (wildcardHit bool) { +func (mgr *Manager) AddDomain(d domain.Domain, accountID types.AccountID, serviceID types.ServiceID) (wildcardHit bool) { name := d.PunycodeString() if e := mgr.findWildcardEntry(name); e != nil { mgr.mu.Lock() diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go index 9a3ed9efd..ceb9ca13a 100644 --- a/proxy/internal/acme/manager_test.go +++ b/proxy/internal/acme/manager_test.go @@ -17,12 +17,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" ) func TestHostPolicy(t *testing.T) { mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil) require.NoError(t, err) - mgr.AddDomain("example.com", "acc1", "rp1") + mgr.AddDomain("example.com", types.AccountID("acc1"), types.ServiceID("rp1")) // Wait for the background prefetch goroutine to finish so the temp dir // can be cleaned up without a race. @@ -92,8 +94,8 @@ func TestDomainStates(t *testing.T) { // AddDomain starts as pending, then the prefetch goroutine will fail // (no real ACME server) and transition to failed. - mgr.AddDomain("a.example.com", "acc1", "rp1") - mgr.AddDomain("b.example.com", "acc1", "rp1") + mgr.AddDomain("a.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) + mgr.AddDomain("b.example.com", types.AccountID("acc1"), types.ServiceID("rp1")) assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered") @@ -209,12 +211,12 @@ func TestWildcardAddDomainSkipsACME(t *testing.T) { require.NoError(t, err) // Add a wildcard-matching domain — should be immediately ready. 
- mgr.AddDomain("foo.example.com", "acc1", "svc1") + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) assert.Equal(t, 0, mgr.PendingCerts(), "wildcard domain should not be pending") assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains()) // Add a non-wildcard domain — should go through ACME (pending then failed). - mgr.AddDomain("other.net", "acc2", "svc2") + mgr.AddDomain("other.net", types.AccountID("acc2"), types.ServiceID("svc2")) assert.Equal(t, 2, mgr.TotalDomains()) // Wait for the ACME prefetch to fail. @@ -234,7 +236,7 @@ func TestWildcardGetCertificate(t *testing.T) { mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil) require.NoError(t, err) - mgr.AddDomain("foo.example.com", "acc1", "svc1") + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) // GetCertificate for a wildcard-matching domain should return the static cert. cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) @@ -255,8 +257,8 @@ func TestMultipleWildcards(t *testing.T) { assert.ElementsMatch(t, []string{"*.example.com", "*.other.org"}, mgr.WildcardPatterns()) // Both wildcards should resolve. - mgr.AddDomain("foo.example.com", "acc1", "svc1") - mgr.AddDomain("bar.other.org", "acc2", "svc2") + mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1")) + mgr.AddDomain("bar.other.org", types.AccountID("acc2"), types.ServiceID("svc2")) assert.Equal(t, 0, mgr.PendingCerts()) assert.ElementsMatch(t, []string{"foo.example.com", "bar.other.org"}, mgr.ReadyDomains()) @@ -271,7 +273,7 @@ func TestMultipleWildcards(t *testing.T) { assert.Contains(t, cert2.Leaf.DNSNames, "*.other.org") // Non-matching domain falls through to ACME. 
- mgr.AddDomain("custom.net", "acc3", "svc3") + mgr.AddDomain("custom.net", types.AccountID("acc3"), types.ServiceID("svc3")) assert.Eventually(t, func() bool { return mgr.PendingCerts() == 0 }, 30*time.Second, 100*time.Millisecond) diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 8a966faa3..3cf86e4b3 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -44,8 +44,8 @@ type DomainConfig struct { Schemes []Scheme SessionPublicKey ed25519.PublicKey SessionExpiration time.Duration - AccountID string - ServiceID string + AccountID types.AccountID + ServiceID types.ServiceID } type validationResult struct { @@ -124,7 +124,7 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) { func setCapturedIDs(r *http.Request, config DomainConfig) { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetAccountId(types.AccountID(config.AccountID)) + cd.SetAccountId(config.AccountID) cd.SetServiceId(config.ServiceID) } } @@ -275,7 +275,7 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool { // session JWTs. Returns an error if the key is missing or invalid. // Callers must not serve the domain if this returns an error, to avoid // exposing an unauthenticated service. 
-func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error { +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error { if len(schemes) == 0 { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() diff --git a/proxy/internal/auth/oidc.go b/proxy/internal/auth/oidc.go index bf178d432..a60e6437a 100644 --- a/proxy/internal/auth/oidc.go +++ b/proxy/internal/auth/oidc.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -17,14 +18,14 @@ type urlGenerator interface { } type OIDC struct { - id string - accountId string + id types.ServiceID + accountId types.AccountID forwardedProto string client urlGenerator } // NewOIDC creates a new OIDC authentication scheme -func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC { +func NewOIDC(client urlGenerator, id types.ServiceID, accountId types.AccountID, forwardedProto string) OIDC { return OIDC{ id: id, accountId: accountId, @@ -53,8 +54,8 @@ func (o OIDC) Authenticate(r *http.Request) (string, string, error) { } res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{ - Id: o.id, - AccountId: o.accountId, + Id: string(o.id), + AccountId: string(o.accountId), RedirectUrl: redirectURL.String(), }) if err != nil { diff --git a/proxy/internal/auth/password.go b/proxy/internal/auth/password.go index 208423465..6a7eda3e1 100644 --- a/proxy/internal/auth/password.go +++ b/proxy/internal/auth/password.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const passwordFormId = "password" type Password struct { 
- id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPassword(client authenticator, id, accountId string) Password { +func NewPassword(client authenticator, id types.ServiceID, accountId types.AccountID) Password { return Password{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Password) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Password{ Password: &proto.PasswordRequest{ Password: password, diff --git a/proxy/internal/auth/pin.go b/proxy/internal/auth/pin.go index c1eb56071..4d08f3dc6 100644 --- a/proxy/internal/auth/pin.go +++ b/proxy/internal/auth/pin.go @@ -5,17 +5,19 @@ import ( "net/http" "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) const pinFormId = "pin" type Pin struct { - id, accountId string - client authenticator + id types.ServiceID + accountId types.AccountID + client authenticator } -func NewPin(client authenticator, id, accountId string) Pin { +func NewPin(client authenticator, id types.ServiceID, accountId types.AccountID) Pin { return Pin{ id: id, accountId: accountId, @@ -41,8 +43,8 @@ func (p Pin) Authenticate(r *http.Request) (string, string, error) { } res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ - Id: p.id, - AccountId: p.accountId, + Id: string(p.id), + AccountId: string(p.accountId), Request: &proto.AuthenticateRequest_Pin{ Pin: &proto.PinRequest{ Pin: pin, diff --git a/proxy/internal/conntrack/conn.go b/proxy/internal/conntrack/conn.go index 97055d992..8446d638f 100644 --- a/proxy/internal/conntrack/conn.go +++ b/proxy/internal/conntrack/conn.go @@ -10,10 +10,11 @@ import ( type trackedConn struct { 
net.Conn tracker *HijackTracker + host string } func (c *trackedConn) Close() error { - c.tracker.conns.Delete(c) + c.tracker.remove(c) return c.Conn.Close() } @@ -22,6 +23,7 @@ func (c *trackedConn) Close() error { type trackingWriter struct { http.ResponseWriter tracker *HijackTracker + host string } func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { @@ -33,8 +35,8 @@ func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if err != nil { return nil, nil, err } - tc := &trackedConn{Conn: conn, tracker: w.tracker} - w.tracker.conns.Store(tc, struct{}{}) + tc := &trackedConn{Conn: conn, tracker: w.tracker, host: w.host} + w.tracker.add(tc) return tc, buf, nil } diff --git a/proxy/internal/conntrack/hijacked.go b/proxy/internal/conntrack/hijacked.go index d76cebc08..911f93f3d 100644 --- a/proxy/internal/conntrack/hijacked.go +++ b/proxy/internal/conntrack/hijacked.go @@ -1,7 +1,6 @@ package conntrack import ( - "net" "net/http" "sync" ) @@ -10,10 +9,14 @@ import ( // upgrades). http.Server.Shutdown does not close hijacked connections, so // they must be tracked and closed explicitly during graceful shutdown. // +// Connections are indexed by the request Host so they can be closed +// per-domain when a service mapping is removed. +// // Use Middleware as the outermost HTTP middleware to ensure hijacked // connections are tracked and automatically deregistered when closed. type HijackTracker struct { - conns sync.Map // net.Conn → struct{} + mu sync.Mutex + conns map[*trackedConn]struct{} } // Middleware returns an HTTP middleware that wraps the ResponseWriter so that @@ -21,21 +24,73 @@ type HijackTracker struct { // tracker when closed. This should be the outermost middleware in the chain. 
func (t *HijackTracker) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r) + next.ServeHTTP(&trackingWriter{ + ResponseWriter: w, + tracker: t, + host: hostOnly(r.Host), + }, r) }) } -// CloseAll closes all tracked hijacked connections and returns the number -// of connections that were closed. +// CloseAll closes all tracked hijacked connections and returns the count. func (t *HijackTracker) CloseAll() int { - var count int - t.conns.Range(func(key, _ any) bool { - if conn, ok := key.(net.Conn); ok { - _ = conn.Close() - count++ - } - t.conns.Delete(key) - return true - }) - return count + t.mu.Lock() + conns := t.conns + t.conns = nil + t.mu.Unlock() + + for tc := range conns { + _ = tc.Conn.Close() + } + return len(conns) +} + +// CloseByHost closes all tracked hijacked connections for the given host +// and returns the number of connections closed. +func (t *HijackTracker) CloseByHost(host string) int { + host = hostOnly(host) + t.mu.Lock() + var toClose []*trackedConn + for tc := range t.conns { + if tc.host == host { + toClose = append(toClose, tc) + } + } + for _, tc := range toClose { + delete(t.conns, tc) + } + t.mu.Unlock() + + for _, tc := range toClose { + _ = tc.Conn.Close() + } + return len(toClose) +} + +func (t *HijackTracker) add(tc *trackedConn) { + t.mu.Lock() + if t.conns == nil { + t.conns = make(map[*trackedConn]struct{}) + } + t.conns[tc] = struct{}{} + t.mu.Unlock() +} + +func (t *HijackTracker) remove(tc *trackedConn) { + t.mu.Lock() + delete(t.conns, tc) + t.mu.Unlock() +} + +// hostOnly strips the port from a host:port string. 
+func hostOnly(hostport string) string { + for i := len(hostport) - 1; i >= 0; i-- { + if hostport[i] == ':' { + return hostport[:i] + } + if hostport[i] < '0' || hostport[i] > '9' { + return hostport + } + } + return hostport } diff --git a/proxy/internal/conntrack/hijacked_test.go b/proxy/internal/conntrack/hijacked_test.go new file mode 100644 index 000000000..9ceefff78 --- /dev/null +++ b/proxy/internal/conntrack/hijacked_test.go @@ -0,0 +1,142 @@ +package conntrack + +import ( + "bufio" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing. +type fakeHijackWriter struct { + http.ResponseWriter + conn net.Conn +} + +func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +} + +func TestCloseByHost(t *testing.T) { + var tracker HijackTracker + + // Simulate hijacking two connections for different hosts. + connA1, connA2 := net.Pipe() + defer connA2.Close() + connB1, connB2 := net.Pipe() + defer connB2.Close() + + twA := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "a.example.com", + } + twB := &trackingWriter{ + ResponseWriter: httptest.NewRecorder(), + tracker: &tracker, + host: "b.example.com", + } + + // Use fakeHijackWriter to provide the Hijack method. + twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1} + twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1} + + _, _, err := twA.Hijack() + require.NoError(t, err) + _, _, err = twB.Hijack() + require.NoError(t, err) + + tracker.mu.Lock() + assert.Equal(t, 2, len(tracker.conns), "should track 2 connections") + tracker.mu.Unlock() + + // Close only host A. 
+ n := tracker.CloseByHost("a.example.com") + assert.Equal(t, 1, n, "should close 1 connection for host A") + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection") + tracker.mu.Unlock() + + // Verify host A's conn is actually closed. + buf := make([]byte, 1) + _, err = connA2.Read(buf) + assert.Error(t, err, "host A pipe should be closed") + + // Host B should still be alive. + go func() { _, _ = connB1.Write([]byte("x")) }() + + // Close all remaining. + n = tracker.CloseAll() + assert.Equal(t, 1, n, "should close remaining 1 connection") + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll") + tracker.mu.Unlock() +} + +func TestCloseAll(t *testing.T) { + var tracker HijackTracker + + for range 5 { + c1, c2 := net.Pipe() + defer c2.Close() + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "test.com"} + tracker.add(tc) + } + + tracker.mu.Lock() + assert.Equal(t, 5, len(tracker.conns)) + tracker.mu.Unlock() + + n := tracker.CloseAll() + assert.Equal(t, 5, n) + + // Double CloseAll is safe. + n = tracker.CloseAll() + assert.Equal(t, 0, n) +} + +func TestTrackedConn_AutoDeregister(t *testing.T) { + var tracker HijackTracker + + c1, c2 := net.Pipe() + defer c2.Close() + + tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"} + tracker.add(tc) + + tracker.mu.Lock() + assert.Equal(t, 1, len(tracker.conns)) + tracker.mu.Unlock() + + // Close the tracked conn: should auto-deregister. 
+ require.NoError(t, tc.Close()) + + tracker.mu.Lock() + assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close") + tracker.mu.Unlock() +} + +func TestHostOnly(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.com:443", "example.com"}, + {"example.com", "example.com"}, + {"127.0.0.1:8080", "127.0.0.1"}, + {"[::1]:443", "[::1]"}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, tt.want, hostOnly(tt.input)) + }) + } +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 885c574bc..01b0bc8e6 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -152,7 +152,7 @@ func (c *Client) printClients(data map[string]any) { return } - _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT") + _, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "SERVICES", "HAS CLIENT") _, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110)) for _, item := range clients { @@ -166,7 +166,7 @@ func (c *Client) printClientRow(item any) { return } - domains := c.extractDomains(client) + services := c.extractServiceKeys(client) hasClient := "no" if hc, ok := client["has_client"].(bool); ok && hc { hasClient = "yes" @@ -175,20 +175,20 @@ func (c *Client) printClientRow(item any) { _, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n", client["account_id"], client["age"], - domains, + services, hasClient, ) } -func (c *Client) extractDomains(client map[string]any) string { - d, ok := client["domains"].([]any) +func (c *Client) extractServiceKeys(client map[string]any) string { + d, ok := client["service_keys"].([]any) if !ok || len(d) == 0 { return "-" } parts := make([]string, len(d)) - for i, domain := range d { - parts[i] = fmt.Sprint(domain) + for i, key := range d { + parts[i] = fmt.Sprint(key) } return strings.Join(parts, ", ") } diff --git a/proxy/internal/debug/handler.go 
b/proxy/internal/debug/handler.go index ab75c8b72..237010922 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -189,7 +189,7 @@ type indexData struct { Version string Uptime string ClientCount int - TotalDomains int + TotalServices int CertsTotal int CertsReady int CertsPending int @@ -202,7 +202,7 @@ type indexData struct { type clientData struct { AccountID string - Domains string + Services string Age string Status string } @@ -211,9 +211,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b clients := h.provider.ListClientsForDebug() sortedIDs := sortedAccountIDs(clients) - totalDomains := 0 + totalServices := 0 for _, info := range clients { - totalDomains += info.DomainCount + totalServices += info.ServiceCount } var certsTotal, certsReady, certsPending, certsFailed int @@ -234,24 +234,24 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } resp := map[string]interface{}{ - "version": version.NetbirdVersion(), - "uptime": time.Since(h.startTime).Round(time.Second).String(), - "client_count": len(clients), - "total_domains": totalDomains, - "certs_total": certsTotal, - "certs_ready": certsReady, - "certs_pending": certsPending, - "certs_failed": certsFailed, - "clients": clientsJSON, + "version": version.NetbirdVersion(), + "uptime": time.Since(h.startTime).Round(time.Second).String(), + 
"client_count": len(clients), + "total_services": totalServices, + "certs_total": certsTotal, + "certs_ready": certsReady, + "certs_pending": certsPending, + "certs_failed": certsFailed, + "clients": clientsJSON, } if len(certsPendingDomains) > 0 { resp["certs_pending_domains"] = certsPendingDomains @@ -278,7 +278,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b Version: version.NetbirdVersion(), Uptime: time.Since(h.startTime).Round(time.Second).String(), ClientCount: len(clients), - TotalDomains: totalDomains, + TotalServices: totalServices, CertsTotal: certsTotal, CertsReady: certsReady, CertsPending: certsPending, @@ -291,9 +291,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -301,7 +301,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) @@ -324,12 +324,12 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] clientsJSON = append(clientsJSON, map[string]interface{}{ - "account_id": info.AccountID, - "domain_count": info.DomainCount, - "domains": info.Domains, - "has_client": info.HasClient, - "created_at": info.CreatedAt, - "age": time.Since(info.CreatedAt).Round(time.Second).String(), + "account_id": info.AccountID, + "service_count": info.ServiceCount, + "service_keys": info.ServiceKeys, + "has_client": info.HasClient, + "created_at": info.CreatedAt, + "age": 
time.Since(info.CreatedAt).Round(time.Second).String(), }) } h.writeJSON(w, map[string]interface{}{ @@ -347,9 +347,9 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want for _, id := range sortedIDs { info := clients[id] - domains := info.Domains.SafeString() - if domains == "" { - domains = "-" + services := strings.Join(info.ServiceKeys, ", ") + if services == "" { + services = "-" } status := "No client" if info.HasClient { @@ -357,7 +357,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want } data.Clients = append(data.Clients, clientData{ AccountID: string(info.AccountID), - Domains: domains, + Services: services, Age: time.Since(info.CreatedAt).Round(time.Second).String(), Status: status, }) diff --git a/proxy/internal/debug/templates/clients.html b/proxy/internal/debug/templates/clients.html index 4d455b2bb..bfc25f95a 100644 --- a/proxy/internal/debug/templates/clients.html +++ b/proxy/internal/debug/templates/clients.html @@ -12,14 +12,14 @@ - + {{range .Clients}} - + diff --git a/proxy/internal/debug/templates/index.html b/proxy/internal/debug/templates/index.html index 16ab3d979..5bd25adfc 100644 --- a/proxy/internal/debug/templates/index.html +++ b/proxy/internal/debug/templates/index.html @@ -27,19 +27,19 @@
    {{range .CertsFailedDomains}}
  • {{.Domain}}: {{.Error}}
  • {{end}}
{{end}} -

Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})

+

Clients ({{.ClientCount}}) | Services ({{.TotalServices}})

{{if .Clients}}
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}
- + {{range .Clients}} - + diff --git a/proxy/internal/metrics/l4_metrics_test.go b/proxy/internal/metrics/l4_metrics_test.go new file mode 100644 index 000000000..055158828 --- /dev/null +++ b/proxy/internal/metrics/l4_metrics_test.go @@ -0,0 +1,69 @@ +package metrics_test + +import ( + "context" + "reflect" + "testing" + "time" + + promexporter "go.opentelemetry.io/otel/exporters/prometheus" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + + "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func newTestMetrics(t *testing.T) *metrics.Metrics { + t.Helper() + + exporter, err := promexporter.New() + if err != nil { + t.Fatalf("create prometheus exporter: %v", err) + } + + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exporter)) + pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath() + meter := provider.Meter(pkg) + + m, err := metrics.New(context.Background(), meter) + if err != nil { + t.Fatalf("create metrics: %v", err) + } + return m +} + +func TestL4ServiceGauge(t *testing.T) { + m := newTestMetrics(t) + + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeTCP) + m.L4ServiceAdded(types.ServiceModeUDP) + m.L4ServiceRemoved(types.ServiceModeTCP) +} + +func TestTCPRelayMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-1") + + m.TCPRelayStarted(acct) + m.TCPRelayStarted(acct) + m.TCPRelayEnded(acct, 10*time.Second, 1000, 500) + m.TCPRelayDialError(acct) + m.TCPRelayRejected(acct) +} + +func TestUDPSessionMetrics(t *testing.T) { + m := newTestMetrics(t) + + acct := types.AccountID("acct-2") + + m.UDPSessionStarted(acct) + m.UDPSessionStarted(acct) + m.UDPSessionEnded(acct) + m.UDPSessionDialError(acct) + m.UDPSessionRejected(acct) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 100) + m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 200) + m.UDPPacketRelayed(types.RelayDirectionBackendToClient, 150) +} diff --git 
a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 68ff55fe5..573485625 100644 --- a/proxy/internal/metrics/metrics.go +++ b/proxy/internal/metrics/metrics.go @@ -6,12 +6,15 @@ import ( "sync" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/responsewriter" + "github.com/netbirdio/netbird/proxy/internal/types" ) +// Metrics collects OpenTelemetry metrics for the proxy. type Metrics struct { ctx context.Context requestsTotal metric.Int64Counter @@ -22,85 +25,188 @@ type Metrics struct { backendDuration metric.Int64Histogram certificateIssueDuration metric.Int64Histogram + // L4 service-level metrics. + l4Services metric.Int64UpDownCounter + + // L4 TCP connection-level metrics. + tcpActiveConns metric.Int64UpDownCounter + tcpConnsTotal metric.Int64Counter + tcpConnDuration metric.Int64Histogram + tcpBytesTotal metric.Int64Counter + + // L4 UDP session-level metrics. + udpActiveSess metric.Int64UpDownCounter + udpSessionsTotal metric.Int64Counter + udpPacketsTotal metric.Int64Counter + udpBytesTotal metric.Int64Counter + mappingsMux sync.Mutex mappingPaths map[string]int } +// New creates a Metrics instance using the given OpenTelemetry meter. 
func New(ctx context.Context, meter metric.Meter) (*Metrics, error) { - requestsTotal, err := meter.Int64Counter( + m := &Metrics{ + ctx: ctx, + mappingPaths: make(map[string]int), + } + + if err := m.initHTTPMetrics(meter); err != nil { + return nil, err + } + if err := m.initL4Metrics(meter); err != nil { + return nil, err + } + + return m, nil +} + +func (m *Metrics) initHTTPMetrics(meter metric.Meter) error { + var err error + + m.requestsTotal, err = meter.Int64Counter( "proxy.http.request.counter", metric.WithUnit("1"), metric.WithDescription("Total number of requests made to the netbird proxy"), ) if err != nil { - return nil, err + return err } - activeRequests, err := meter.Int64UpDownCounter( + m.activeRequests, err = meter.Int64UpDownCounter( "proxy.http.active_requests", metric.WithUnit("1"), metric.WithDescription("Current in-flight requests handled by the netbird proxy"), ) if err != nil { - return nil, err + return err } - configuredDomains, err := meter.Int64UpDownCounter( + m.configuredDomains, err = meter.Int64UpDownCounter( "proxy.domains.count", metric.WithUnit("1"), metric.WithDescription("Current number of domains configured on the netbird proxy"), ) if err != nil { - return nil, err + return err } - totalPaths, err := meter.Int64UpDownCounter( + m.totalPaths, err = meter.Int64UpDownCounter( "proxy.paths.count", metric.WithUnit("1"), metric.WithDescription("Total number of paths configured on the netbird proxy"), ) if err != nil { - return nil, err + return err } - requestDuration, err := meter.Int64Histogram( + m.requestDuration, err = meter.Int64Histogram( "proxy.http.request.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of requests made to the netbird proxy"), ) if err != nil { - return nil, err + return err } - backendDuration, err := meter.Int64Histogram( + m.backendDuration, err = meter.Int64Histogram( "proxy.backend.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of peer 
round trip time from the netbird proxy"), ) if err != nil { - return nil, err + return err } - certificateIssueDuration, err := meter.Int64Histogram( + m.certificateIssueDuration, err = meter.Int64Histogram( "proxy.certificate.issue.duration.ms", metric.WithUnit("milliseconds"), metric.WithDescription("Duration of ACME certificate issuance"), ) + return err +} + +func (m *Metrics) initL4Metrics(meter metric.Meter) error { + var err error + + m.l4Services, err = meter.Int64UpDownCounter( + "proxy.l4.services.count", + metric.WithUnit("1"), + metric.WithDescription("Current number of configured L4 services (TCP/TLS/UDP) by mode"), + ) if err != nil { - return nil, err + return err } - return &Metrics{ - ctx: ctx, - requestsTotal: requestsTotal, - activeRequests: activeRequests, - configuredDomains: configuredDomains, - totalPaths: totalPaths, - requestDuration: requestDuration, - backendDuration: backendDuration, - certificateIssueDuration: certificateIssueDuration, - mappingPaths: make(map[string]int), - }, nil + m.tcpActiveConns, err = meter.Int64UpDownCounter( + "proxy.tcp.active_connections", + metric.WithUnit("1"), + metric.WithDescription("Current number of active TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpConnsTotal, err = meter.Int64Counter( + "proxy.tcp.connections.total", + metric.WithUnit("1"), + metric.WithDescription("Total TCP/TLS relay connections by result and account"), + ) + if err != nil { + return err + } + + m.tcpConnDuration, err = meter.Int64Histogram( + "proxy.tcp.connection.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration of TCP/TLS relay connections"), + ) + if err != nil { + return err + } + + m.tcpBytesTotal, err = meter.Int64Counter( + "proxy.tcp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through TCP/TLS relay by direction"), + ) + if err != nil { + return err + } + + m.udpActiveSess, err = meter.Int64UpDownCounter( + 
"proxy.udp.active_sessions", + metric.WithUnit("1"), + metric.WithDescription("Current number of active UDP relay sessions"), + ) + if err != nil { + return err + } + + m.udpSessionsTotal, err = meter.Int64Counter( + "proxy.udp.sessions.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP relay sessions by result and account"), + ) + if err != nil { + return err + } + + m.udpPacketsTotal, err = meter.Int64Counter( + "proxy.udp.packets.total", + metric.WithUnit("1"), + metric.WithDescription("Total UDP packets relayed by direction"), + ) + if err != nil { + return err + } + + m.udpBytesTotal, err = meter.Int64Counter( + "proxy.udp.bytes.total", + metric.WithUnit("bytes"), + metric.WithDescription("Total bytes transferred through UDP relay by direction"), + ) + return err } type responseInterceptor struct { @@ -120,6 +226,13 @@ func (w *responseInterceptor) Write(b []byte) (int, error) { return size, err } +// Unwrap returns the underlying ResponseWriter so http.ResponseController +// can reach through to the original writer for Hijack/Flush operations. +func (w *responseInterceptor) Unwrap() http.ResponseWriter { + return w.PassthroughWriter +} + +// Middleware wraps an HTTP handler with request metrics. func (m *Metrics) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.requestsTotal.Add(m.ctx, 1) @@ -144,6 +257,7 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } +// RoundTripper wraps an http.RoundTripper with backend duration metrics. func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { start := time.Now() @@ -156,6 +270,7 @@ func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { }) } +// AddMapping records that a domain mapping was added. 
func (m *Metrics) AddMapping(mapping proxy.Mapping) { m.mappingsMux.Lock() defer m.mappingsMux.Unlock() @@ -175,13 +290,13 @@ func (m *Metrics) AddMapping(mapping proxy.Mapping) { m.mappingPaths[mapping.Host] = newPathCount } +// RemoveMapping records that a domain mapping was removed. func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { m.mappingsMux.Lock() defer m.mappingsMux.Unlock() oldPathCount, exists := m.mappingPaths[mapping.Host] if !exists { - // Nothing to remove return } @@ -195,3 +310,80 @@ func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { func (m *Metrics) RecordCertificateIssuance(duration time.Duration) { m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds()) } + +// L4ServiceAdded increments the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceAdded(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, 1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// L4ServiceRemoved decrements the L4 service gauge for the given mode. +func (m *Metrics) L4ServiceRemoved(mode types.ServiceMode) { + m.l4Services.Add(m.ctx, -1, metric.WithAttributes(attribute.String("mode", string(mode)))) +} + +// TCPRelayStarted records a new TCP relay connection starting. +func (m *Metrics) TCPRelayStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// TCPRelayEnded records a TCP relay connection ending and accumulates bytes and duration. 
+func (m *Metrics) TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) { + acct := attribute.String("account_id", string(accountID)) + m.tcpActiveConns.Add(m.ctx, -1, metric.WithAttributes(acct)) + m.tcpConnDuration.Record(m.ctx, duration.Milliseconds(), metric.WithAttributes(acct)) + m.tcpBytesTotal.Add(m.ctx, srcToDst, metric.WithAttributes(attribute.String("direction", "client_to_backend"))) + m.tcpBytesTotal.Add(m.ctx, dstToSrc, metric.WithAttributes(attribute.String("direction", "backend_to_client"))) +} + +// TCPRelayDialError records a dial failure for a TCP relay. +func (m *Metrics) TCPRelayDialError(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// TCPRelayRejected records a rejected TCP relay (semaphore full). +func (m *Metrics) TCPRelayRejected(accountID types.AccountID) { + m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPSessionStarted records a new UDP session starting. +func (m *Metrics) UDPSessionStarted(accountID types.AccountID) { + acct := attribute.String("account_id", string(accountID)) + m.udpActiveSess.Add(m.ctx, 1, metric.WithAttributes(acct)) + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success"))) +} + +// UDPSessionEnded records a UDP session ending. +func (m *Metrics) UDPSessionEnded(accountID types.AccountID) { + m.udpActiveSess.Add(m.ctx, -1, metric.WithAttributes(attribute.String("account_id", string(accountID)))) +} + +// UDPSessionDialError records a dial failure for a UDP session. 
+func (m *Metrics) UDPSessionDialError(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "dial_error"), + )) +} + +// UDPSessionRejected records a rejected UDP session (limit or rate limited). +func (m *Metrics) UDPSessionRejected(accountID types.AccountID) { + m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes( + attribute.String("account_id", string(accountID)), + attribute.String("result", "rejected"), + )) +} + +// UDPPacketRelayed records a packet relayed in the given direction with its size in bytes. +func (m *Metrics) UDPPacketRelayed(direction types.RelayDirection, bytes int) { + dir := attribute.String("direction", string(direction)) + m.udpPacketsTotal.Add(m.ctx, 1, metric.WithAttributes(dir)) + m.udpBytesTotal.Add(m.ctx, int64(bytes), metric.WithAttributes(dir)) +} diff --git a/proxy/internal/netutil/errors.go b/proxy/internal/netutil/errors.go new file mode 100644 index 000000000..ff24e33d4 --- /dev/null +++ b/proxy/internal/netutil/errors.go @@ -0,0 +1,40 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net" + "syscall" +) + +// ValidatePort converts an int32 proto port to uint16, returning an error +// if the value is out of the valid 1–65535 range. +func ValidatePort(port int32) (uint16, error) { + if port <= 0 || port > math.MaxUint16 { + return 0, fmt.Errorf("invalid port %d: must be 1–65535", port) + } + return uint16(port), nil +} + +// IsExpectedError returns true for errors that are normal during +// connection teardown and should not be logged as warnings. +func IsExpectedError(err error) bool { + return errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ECONNABORTED) +} + +// IsTimeout checks whether the error is a network timeout. 
+func IsTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +} diff --git a/proxy/internal/netutil/errors_test.go b/proxy/internal/netutil/errors_test.go new file mode 100644 index 000000000..7d6be10ff --- /dev/null +++ b/proxy/internal/netutil/errors_test.go @@ -0,0 +1,92 @@ +package netutil + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidatePort(t *testing.T) { + tests := []struct { + name string + port int32 + want uint16 + wantErr bool + }{ + {"valid min", 1, 1, false}, + {"valid mid", 8080, 8080, false}, + {"valid max", 65535, 65535, false}, + {"zero", 0, 0, true}, + {"negative", -1, 0, true}, + {"too large", 65536, 0, true}, + {"way too large", 100000, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ValidatePort(tt.port) + if tt.wantErr { + assert.Error(t, err) + assert.Zero(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestIsExpectedError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net.ErrClosed", net.ErrClosed, true}, + {"context.Canceled", context.Canceled, true}, + {"io.EOF", io.EOF, true}, + {"ECONNRESET", syscall.ECONNRESET, true}, + {"EPIPE", syscall.EPIPE, true}, + {"ECONNABORTED", syscall.ECONNABORTED, true}, + {"wrapped expected", fmt.Errorf("wrap: %w", net.ErrClosed), true}, + {"unexpected EOF", io.ErrUnexpectedEOF, false}, + {"generic error", errors.New("something"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsExpectedError(tt.err)) + }) + } +} + +type timeoutErr struct{ timeout bool } + +func (e *timeoutErr) Error() string { return "timeout" } +func (e *timeoutErr) Timeout() bool { return e.timeout } +func (e *timeoutErr) Temporary() bool { return false } + +func 
TestIsTimeout(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"net timeout", &timeoutErr{timeout: true}, true}, + {"net non-timeout", &timeoutErr{timeout: false}, false}, + {"wrapped timeout", fmt.Errorf("wrap: %w", &timeoutErr{timeout: true}), true}, + {"generic error", errors.New("not a timeout"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsTimeout(tt.err)) + }) + } +} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index 22ebbf371..4a61f6bcf 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "net/netip" "sync" "github.com/netbirdio/netbird/proxy/internal/types" @@ -47,10 +48,10 @@ func (o ResponseOrigin) String() string { type CapturedData struct { mu sync.RWMutex RequestID string - ServiceId string + ServiceId types.ServiceID AccountId types.AccountID Origin ResponseOrigin - ClientIP string + ClientIP netip.Addr UserID string AuthMethod string } @@ -63,14 +64,14 @@ func (c *CapturedData) GetRequestID() string { } // SetServiceId safely sets the service ID -func (c *CapturedData) SetServiceId(serviceId string) { +func (c *CapturedData) SetServiceId(serviceId types.ServiceID) { c.mu.Lock() defer c.mu.Unlock() c.ServiceId = serviceId } // GetServiceId safely gets the service ID -func (c *CapturedData) GetServiceId() string { +func (c *CapturedData) GetServiceId() types.ServiceID { c.mu.RLock() defer c.mu.RUnlock() return c.ServiceId @@ -105,14 +106,14 @@ func (c *CapturedData) GetOrigin() ResponseOrigin { } // SetClientIP safely sets the resolved client IP. -func (c *CapturedData) SetClientIP(ip string) { +func (c *CapturedData) SetClientIP(ip netip.Addr) { c.mu.Lock() defer c.mu.Unlock() c.ClientIP = ip } // GetClientIP safely gets the resolved client IP. 
-func (c *CapturedData) GetClientIP() string { +func (c *CapturedData) GetClientIP() netip.Addr { c.mu.RLock() defer c.mu.RUnlock() return c.ClientIP @@ -161,13 +162,13 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData { return data } -func withServiceId(ctx context.Context, serviceId string) context.Context { +func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context { return context.WithValue(ctx, serviceIdKey, serviceId) } -func ServiceIdFromContext(ctx context.Context) string { +func ServiceIdFromContext(ctx context.Context) types.ServiceID { v := ctx.Value(serviceIdKey) - serviceId, ok := v.(string) + serviceId, ok := v.(types.ServiceID) if !ok { return "" } diff --git a/proxy/internal/proxy/proxy_bench_test.go b/proxy/internal/proxy/proxy_bench_test.go index 5af2167e6..b59ef75c0 100644 --- a/proxy/internal/proxy/proxy_bench_test.go +++ b/proxy/internal/proxy/proxy_bench_test.go @@ -25,7 +25,7 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) { func BenchmarkServeHTTP(b *testing.B) { rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil) rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", Paths: map[string]*proxy.PathTarget{ @@ -66,7 +66,7 @@ func BenchmarkServeHTTPHostCount(b *testing.B) { target = id } rp.AddMapping(proxy.Mapping{ - ID: id, + ID: types.ServiceID(id), AccountID: types.AccountID(rand.Text()), Host: host, Paths: map[string]*proxy.PathTarget{ @@ -118,7 +118,7 @@ func BenchmarkServeHTTPPathCount(b *testing.B) { } } rp.AddMapping(proxy.Mapping{ - ID: rand.Text(), + ID: types.ServiceID(rand.Text()), AccountID: types.AccountID(rand.Text()), Host: "app.example.com", Paths: paths, diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index b0001d5b9..1ee9b2a42 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go 
@@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" ) @@ -86,9 +87,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = roundtrip.WithSkipTLSVerify(ctx) } if pt.RequestTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout) - defer cancel() + ctx = types.WithDialTimeout(ctx, pt.RequestTimeout) } rewriteMatchedPath := result.matchedPath @@ -142,9 +141,9 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost r.Out.Header.Set(k, v) } - clientIP := extractClientIP(r.In.RemoteAddr) + clientIP := extractHostIP(r.In.RemoteAddr) - if IsTrustedProxy(clientIP, p.trustedProxies) { + if isTrustedAddr(clientIP, p.trustedProxies) { p.setTrustedForwardingHeaders(r, clientIP) } else { p.setUntrustedForwardingHeaders(r, clientIP) @@ -214,12 +213,14 @@ func normalizeHost(u *url.URL) string { // setTrustedForwardingHeaders appends to the existing forwarding header chain // and preserves upstream-provided headers when the direct connection is from // a trusted proxy. -func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() + // Append the direct connection IP to the existing X-Forwarded-For chain. if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" { - r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP) + r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr) } else { - r.Out.Header.Set("X-Forwarded-For", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) } // Preserve upstream X-Real-IP if present; otherwise resolve through the chain. 
@@ -227,7 +228,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli r.Out.Header.Set("X-Real-IP", realIP) } else { resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies) - r.Out.Header.Set("X-Real-IP", resolved) + r.Out.Header.Set("X-Real-IP", resolved.String()) } // Preserve upstream X-Forwarded-Host if present. @@ -257,10 +258,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli // sets them fresh based on the direct connection. This is the default // behavior when no trusted proxies are configured or the direct connection // is from an untrusted source. -func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { +func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) { + ipStr := clientIP.String() proto := auth.ResolveProto(p.forwardedProto, r.In.TLS) - r.Out.Header.Set("X-Forwarded-For", clientIP) - r.Out.Header.Set("X-Real-IP", clientIP) + r.Out.Header.Set("X-Forwarded-For", ipStr) + r.Out.Header.Set("X-Real-IP", ipStr) r.Out.Header.Set("X-Forwarded-Host", r.In.Host) r.Out.Header.Set("X-Forwarded-Proto", proto) r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto)) @@ -288,16 +290,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) { } } -// extractClientIP extracts the IP address from an http.Request.RemoteAddr -// which is always in host:port format. -func extractClientIP(remoteAddr string) string { - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return remoteAddr - } - return ip -} - // extractForwardedPort returns the port from the Host header if present, // otherwise defaults to the standard port for the resolved protocol. 
func extractForwardedPort(host, resolvedProto string) string { @@ -327,10 +319,12 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { web.ServeErrorPage(w, r, code, title, message, requestID, status) } -// getClientIP retrieves the resolved client IP from context. +// getClientIP retrieves the resolved client IP string from context. func getClientIP(r *http.Request) string { if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil { - return capturedData.GetClientIP() + if ip := capturedData.GetClientIP(); ip.IsValid() { + return ip.String() + } } return "" } diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index be2fb9105..b05ead198 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { }) } -func TestExtractClientIP(t *testing.T) { +func TestExtractHostIP(t *testing.T) { tests := []struct { name string remoteAddr string - expected string + expected netip.Addr }{ - {"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"}, - {"IPv6 with port", "[::1]:12345", "::1"}, - {"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"}, - {"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"}, - {"IPv6 without brackets fallback", "::1", "::1"}, - {"empty string fallback", "", ""}, - {"public IP", "203.0.113.50:9999", "203.0.113.50"}, + {"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")}, + {"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")}, + {"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")}, + {"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")}, + {"empty string fallback", "", netip.Addr{}}, + {"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")}, } for _, tt := range 
tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr)) + assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr)) }) } } diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 58b92ff9e..1513fbe45 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -30,8 +30,9 @@ type PathTarget struct { CustomHeaders map[string]string } +// Mapping describes how a domain is routed by the HTTP reverse proxy. type Mapping struct { - ID string + ID types.ServiceID AccountID types.AccountID Host string Paths map[string]*PathTarget @@ -42,7 +43,7 @@ type Mapping struct { type targetResult struct { target *PathTarget matchedPath string - serviceID string + serviceID types.ServiceID accountID types.AccountID passHostHeader bool rewriteRedirects bool @@ -101,8 +102,13 @@ func (p *ReverseProxy) AddMapping(m Mapping) { p.mappings[m.Host] = m } -func (p *ReverseProxy) RemoveMapping(m Mapping) { +// RemoveMapping removes the mapping for the given host and reports whether it existed. +func (p *ReverseProxy) RemoveMapping(m Mapping) bool { p.mappingsMux.Lock() defer p.mappingsMux.Unlock() + if _, ok := p.mappings[m.Host]; !ok { + return false + } delete(p.mappings, m.Host) + return true } diff --git a/proxy/internal/proxy/trustedproxy.go b/proxy/internal/proxy/trustedproxy.go index ad9a5b6c0..0fe693f90 100644 --- a/proxy/internal/proxy/trustedproxy.go +++ b/proxy/internal/proxy/trustedproxy.go @@ -7,21 +7,11 @@ import ( // IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes. 
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { - if len(trusted) == 0 { - return false - } - addr, err := netip.ParseAddr(ipStr) - if err != nil { + if err != nil || len(trusted) == 0 { return false } - - for _, prefix := range trusted { - if prefix.Contains(addr) { - return true - } - } - return false + return isTrustedAddr(addr.Unmap(), trusted) } // ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list. @@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { // // If the trusted list is empty or remoteAddr is not trusted, it returns the // remoteAddr IP directly (ignoring any forwarding headers). -func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { - remoteIP := extractClientIP(remoteAddr) +func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr { + remoteIP := extractHostIP(remoteAddr) - if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) { + if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) { return remoteIP } @@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { if ip == "" { continue } - if !IsTrustedProxy(ip, trusted) { - return ip + addr, err := netip.ParseAddr(ip) + if err != nil { + continue + } + addr = addr.Unmap() + if !isTrustedAddr(addr, trusted) { + return addr } } // All IPs in XFF are trusted; return the leftmost as best guess. if first := strings.TrimSpace(parts[0]); first != "" { - return first + if addr, err := netip.ParseAddr(first); err == nil { + return addr.Unmap() + } } return remoteIP } + +// extractHostIP parses the IP from a host:port string and returns it unmapped. 
+func extractHostIP(hostPort string) netip.Addr { + if ap, err := netip.ParseAddrPort(hostPort); err == nil { + return ap.Addr().Unmap() + } + if addr, err := netip.ParseAddr(hostPort); err == nil { + return addr.Unmap() + } + return netip.Addr{} +} + +// isTrustedAddr checks if the given address falls within any of the trusted prefixes. +func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool { + if !addr.IsValid() { + return false + } + for _, prefix := range trusted { + if prefix.Contains(addr) { + return true + } + } + return false +} diff --git a/proxy/internal/proxy/trustedproxy_test.go b/proxy/internal/proxy/trustedproxy_test.go index 827b7babf..35ed1f5c2 100644 --- a/proxy/internal/proxy/trustedproxy_test.go +++ b/proxy/internal/proxy/trustedproxy_test.go @@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) { remoteAddr string xff string trusted []netip.Prefix - want string + want netip.Addr }{ { name: "empty trusted list returns RemoteAddr", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4", trusted: nil, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "untrusted RemoteAddr ignores XFF", remoteAddr: "203.0.113.50:9999", xff: "1.2.3.4, 10.0.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with single client in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr walks past trusted entries in XFF", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50, 10.0.0.2, 172.16.0.5", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr", remoteAddr: "10.0.0.1:5000", xff: "", trusted: trusted, - want: "10.0.0.1", + want: netip.MustParseAddr("10.0.0.1"), }, { name: "all XFF IPs trusted returns leftmost", remoteAddr: "10.0.0.1:5000", 
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3", trusted: trusted, - want: "10.0.0.2", + want: netip.MustParseAddr("10.0.0.2"), }, { name: "XFF with whitespace", remoteAddr: "10.0.0.1:5000", xff: " 203.0.113.50 , 10.0.0.2 ", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "XFF with empty segments", remoteAddr: "10.0.0.1:5000", xff: "203.0.113.50,,10.0.0.2", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "multi-hop with mixed trust", remoteAddr: "10.0.0.1:5000", xff: "8.8.8.8, 203.0.113.50, 172.16.0.1", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, { name: "RemoteAddr without port", remoteAddr: "10.0.0.1", xff: "203.0.113.50", trusted: trusted, - want: "203.0.113.50", + want: netip.MustParseAddr("203.0.113.50"), }, } for _, tt := range tests { diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 57770f4a5..e38e3dc4e 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "sync" "time" @@ -14,11 +15,12 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/embed" nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -26,7 +28,22 @@ import ( const deviceNamePrefix = "ingress-proxy-" // backendKey identifies a backend by its host:port from the target URL. 
-type backendKey = string +type backendKey string + +// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service) +// that holds a reference to an embedded NetBird client. Callers should use the +// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions. +type ServiceKey string + +// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service. +func DomainServiceKey(domain string) ServiceKey { + return ServiceKey("domain:" + domain) +} + +// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP). +func L4ServiceKey(id types.ServiceID) ServiceKey { + return ServiceKey("l4:" + id) +} var ( // ErrNoAccountID is returned when a request context is missing the account ID. @@ -39,24 +56,24 @@ var ( ErrTooManyInflight = errors.New("too many in-flight requests") ) -// domainInfo holds metadata about a registered domain. -type domainInfo struct { - serviceID string +// serviceInfo holds metadata about a registered service. +type serviceInfo struct { + serviceID types.ServiceID } -type domainNotification struct { - domain domain.Domain - serviceID string +type serviceNotification struct { + key ServiceKey + serviceID types.ServiceID } -// clientEntry holds an embedded NetBird client and tracks which domains use it. +// clientEntry holds an embedded NetBird client and tracks which services use it. type clientEntry struct { client *embed.Client transport *http.Transport // insecureTransport is a clone of transport with TLS verification disabled, // used when per-target skip_tls_verify is set. insecureTransport *http.Transport - domains map[domain.Domain]domainInfo + services map[ServiceKey]serviceInfo createdAt time.Time started bool // Per-backend in-flight limiting keyed by target host:port. @@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo // ClientConfig holds configuration for the embedded NetBird client. 
type ClientConfig struct { MgmtAddr string - WGPort int + WGPort uint16 PreSharedKey string } type statusNotifier interface { - NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error + NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error } type managementClient interface { @@ -107,7 +124,7 @@ type managementClient interface { // NetBird provides an http.RoundTripper implementation // backed by underlying NetBird connections. -// Clients are keyed by AccountID, allowing multiple domains to share the same connection. +// Clients are keyed by AccountID, allowing multiple services to share the same connection. type NetBird struct { proxyID string proxyAddr string @@ -124,11 +141,11 @@ type NetBird struct { // ClientDebugInfo contains debug information about a client. type ClientDebugInfo struct { - AccountID types.AccountID - DomainCount int - Domains domain.List - HasClient bool - CreatedAt time.Time + AccountID types.AccountID + ServiceCount int + ServiceKeys []string + HasClient bool + CreatedAt time.Time } // accountIDContextKey is the context key for storing the account ID. @@ -137,37 +154,37 @@ type accountIDContextKey struct{} // skipTLSVerifyContextKey is the context key for requesting insecure TLS. type skipTLSVerifyContextKey struct{} -// AddPeer registers a domain for an account. If the account doesn't have a client yet, +// AddPeer registers a service for an account. If the account doesn't have a client yet, // one is created by authenticating with the management server using the provided token. -// Multiple domains can share the same client. -func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error { +// Multiple services can share the same client. 
+func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error { + si := serviceInfo{serviceID: serviceID} + n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { - // Client already exists for this account, just register the domain - entry.domains[d] = domainInfo{serviceID: serviceID} + entry.services[key] = si started := entry.started n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("registered domain with existing client") + "account_id": accountID, + "service_key": key, + }).Debug("registered service with existing client") - // If client is already started, notify this domain as connected immediately if started && n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).WithError(err).Warn("failed to notify status for existing client") } } return nil } - entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID) + entry, err := n.createClientEntry(ctx, accountID, key, authToken, si) if err != nil { n.clientsMux.Unlock() return err @@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, + "account_id": accountID, + "service_key": key, }).Info("created new client for account") // Attempt to start the client in the background; if this fails we will @@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma // createClientEntry generates a WireGuard keypair, authenticates with management, // and creates an embedded NetBird client. 
Must be called with clientsMux held. -func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) { +func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) { + serviceID := si.serviceID n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, @@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account }).Debug("authenticating new proxy peer with management") resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{ - ServiceId: serviceID, + ServiceId: string(serviceID), AccountId: string(accountID), Token: authToken, WireguardPublicKey: publicKey.String(), @@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // Create embedded NetBird client with the generated private key. // The peer has already been created via CreateProxyPeer RPC with the public key. + wgPort := int(n.clientCfg.WGPort) client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), BlockInbound: true, - WireguardPort: &n.clientCfg.WGPort, + WireguardPort: &wgPort, PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { @@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // the client's HTTPClient to avoid issues with request validation that do // not work with reverse proxied requests. 
transport := &http.Transport{ - DialContext: client.DialContext, + DialContext: dialWithTimeout(client.DialContext), ForceAttemptHTTP2: true, MaxIdleConns: n.transportCfg.maxIdleConns, MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost, @@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account return &clientEntry{ client: client, - domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}}, + services: map[ServiceKey]serviceInfo{key: si}, transport: transport, insecureTransport: insecureTransport, createdAt: time.Now(), @@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account }, nil } -// runClientStartup starts the client and notifies registered domains on success. +// runClientStartup starts the client and notifies registered services on success. func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) { startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI return } - // Mark client as started and collect domains to notify outside the lock. + // Mark client as started and collect services to notify outside the lock. 
n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { entry.started = true } - var domainsToNotify []domainNotification + var toNotify []serviceNotification if exists { - for dom, info := range entry.domains { - domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID}) + for key, info := range entry.services { + toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID}) } } n.clientsMux.Unlock() @@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI if n.statusNotifier == nil { return } - for _, dn := range domainsToNotify { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil { + for _, sn := range toNotify { + if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).WithError(err).Warn("failed to notify tunnel connection status") } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": dn.domain, + "account_id": accountID, + "service_key": sn.key, }).Info("notified management about tunnel connection") } } } -// RemovePeer unregisters a domain from an account. The client is only stopped -// when no domains are using it anymore. -func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error { +// RemovePeer unregisters a service from an account. The client is only stopped +// when no services are using it anymore. 
+func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error { n.clientsMux.Lock() entry, exists := n.clients[accountID] @@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d return nil } - // Get domain info before deleting - domInfo, domainExists := entry.domains[d] - if !domainExists { + si, svcExists := entry.services[key] + if !svcExists { n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).Debug("remove peer: domain not registered") + "account_id": accountID, + "service_key": key, + }).Debug("remove peer: service not registered") return nil } - delete(entry.domains, d) - - // If there are still domains using this client, keep it running - if len(entry.domains) > 0 { - n.clientsMux.Unlock() + delete(entry.services, key) + stopClient := len(entry.services) == 0 + var client *embed.Client + var transport, insecureTransport *http.Transport + if stopClient { + n.logger.WithField("account_id", accountID).Info("stopping client, no more services") + client = entry.client + transport = entry.transport + insecureTransport = entry.insecureTransport + delete(n.clients, accountID) + } else { n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - "remaining_domains": len(entry.domains), - }).Debug("unregistered domain, client still in use") - - // Notify this domain as disconnected - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") - } - } - return nil + "account_id": accountID, + "service_key": key, + "remaining_services": len(entry.services), + }).Debug("unregistered service, client still in use") } - - // No more domains using this client, stop it - 
n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).Info("stopping client, no more domains") - - client := entry.client - transport := entry.transport - insecureTransport := entry.insecureTransport - delete(n.clients, accountID) n.clientsMux.Unlock() - // Notify disconnection before stopping - if n.statusNotifier != nil { - if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - "domain": d, - }).WithError(err).Warn("failed to notify tunnel disconnection status") + n.notifyDisconnect(ctx, accountID, key, si.serviceID) + + if stopClient { + transport.CloseIdleConnections() + insecureTransport.CloseIdleConnections() + if err := client.Stop(ctx); err != nil { + n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client") } } - transport.CloseIdleConnections() - insecureTransport.CloseIdleConnections() - - if err := client.Stop(ctx); err != nil { - n.logger.WithFields(log.Fields{ - "account_id": accountID, - }).WithError(err).Warn("failed to stop netbird client") - } - return nil } +func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) { + if n.statusNotifier == nil { + return + } + if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil { + if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound { + n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification") + } else { + n.logger.WithFields(log.Fields{ + "account_id": accountID, + "service_key": key, + }).WithError(err).Warn("failed to notify tunnel disconnection status") + } + } +} + // RoundTrip implements http.RoundTripper. It looks up the client for the account // specified in the request context and uses it to dial the backend. 
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { @@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { } n.clientsMux.RUnlock() - release, ok := entry.acquireInflight(req.URL.Host) + release, ok := entry.acquireInflight(backendKey(req.URL.Host)) defer release() if !ok { return nil, ErrTooManyInflight @@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool { return exists } -// DomainCount returns the number of domains registered for the given account. +// ServiceCount returns the number of services registered for the given account. // Returns 0 if the account has no client. -func (n *NetBird) DomainCount(accountID types.AccountID) int { +func (n *NetBird) ServiceCount(accountID types.AccountID) int { n.clientsMux.RLock() defer n.clientsMux.RUnlock() entry, exists := n.clients[accountID] if !exists { return 0 } - return len(entry.domains) + return len(entry.services) } // ClientCount returns the total number of active clients. 
@@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { result := make(map[types.AccountID]ClientDebugInfo) for accountID, entry := range n.clients { - domains := make(domain.List, 0, len(entry.domains)) - for d := range entry.domains { - domains = append(domains, d) + keys := make([]string, 0, len(entry.services)) + for k := range entry.services { + keys = append(keys, string(k)) } result[accountID] = ClientDebugInfo{ - AccountID: accountID, - DomainCount: len(entry.domains), - Domains: domains, - HasClient: entry.client != nil, - CreatedAt: entry.createdAt, + AccountID: accountID, + ServiceCount: len(entry.services), + ServiceKeys: keys, + HasClient: entry.client != nil, + CreatedAt: entry.createdAt, } } return result @@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L } } +// dialWithTimeout wraps a DialContext function so that any dial timeout +// stored in the context (via types.WithDialTimeout) is applied only to +// the connection establishment phase, not the full request lifetime. +func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + if d, ok := types.DialTimeoutFromContext(ctx); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, d) + defer cancel() + } + return dial(ctx, network, addr) + } +} + // WithAccountID adds the account ID to the context. 
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context { return context.WithValue(ctx, accountIDContextKey{}, accountID) diff --git a/proxy/internal/roundtrip/netbird_bench_test.go b/proxy/internal/roundtrip/netbird_bench_test.go index e89213c33..330ea0332 100644 --- a/proxy/internal/roundtrip/netbird_bench_test.go +++ b/proxy/internal/roundtrip/netbird_bench_test.go @@ -1,6 +1,7 @@ package roundtrip import ( + "context" "crypto/rand" "math/big" "sync" @@ -8,7 +9,6 @@ import ( "time" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" ) // Simple benchmark for comparison with AddPeer contention. @@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { target = id } nb.clients[id] = &clientEntry{ - domains: map[domain.Domain]domainInfo{ - domain.Domain(rand.Text()): { - serviceID: rand.Text(), + services: map[ServiceKey]serviceInfo{ + ServiceKey(rand.Text()): { + serviceID: types.ServiceID(rand.Text()), }, }, createdAt: time.Now(), @@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } // Launch workers that continuously call AddPeer with new random accountIDs. 
+ ctx, cancel := context.WithCancel(b.Context()) var wg sync.WaitGroup for range addPeerWorkers { - wg.Go(func() { - for { - if err := nb.AddPeer(b.Context(), + wg.Add(1) + go func() { + defer wg.Done() + for ctx.Err() == nil { + if err := nb.AddPeer(ctx, types.AccountID(rand.Text()), - domain.Domain(rand.Text()), + ServiceKey(rand.Text()), rand.Text(), - rand.Text()); err != nil { - b.Log(err) + types.ServiceID(rand.Text())); err != nil { + return } } - }) + }() } // Benchmark calling HasClient during AddPeer contention. @@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) { } }) b.StopTimer() + cancel() + wg.Wait() } diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index 0a742c2fa..5444f6c11 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -27,16 +26,15 @@ type mockStatusNotifier struct { } type statusCall struct { - accountID string - serviceID string - domain string + accountID types.AccountID + serviceID types.ServiceID connected bool } -func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error { +func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { m.mu.Lock() defer m.mu.Unlock() - m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected}) + m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected}) return nil } @@ -62,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { // Initially no client exists. 
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") - // Add first domain - this should create a new client. - // Note: This will fail to actually connect since we use an invalid URL, - // but the client entry should still be created. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service - this should create a new client. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID), "should have client after AddPeer") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add first domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add first service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - assert.Equal(t, 1, nb.DomainCount(accountID)) + assert.Equal(t, 1, nb.ServiceCount(accountID)) - // Add second domain for the same account - should reuse existing client. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + // Add second service for the same account - should reuse existing client. 
+ err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service") - // Add third domain. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + // Add third service. + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain") + assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service") // Still only one client. assert.True(t, nb.HasClient(accountID)) @@ -102,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) { account1 := types.AccountID("account-1") account2 := types.AccountID("account-2") - // Add domain for account 1. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add service for account 1. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Add domain for account 2. - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2") + // Add service for account 2. + err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2")) require.NoError(t, err) // Both accounts should have their own clients. 
assert.True(t, nb.HasClient(account1), "account1 should have client") assert.True(t, nb.HasClient(account2), "account2 should have client") - assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1") - assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1") + assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1") } -func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) { +func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add multiple domains. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add multiple services. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3") + err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3")) require.NoError(t, err) - assert.Equal(t, 3, nb.DomainCount(accountID)) + assert.Equal(t, 3, nb.ServiceCount(accountID)) - // Remove one domain - client should remain. + // Remove one service - client should remain. 
err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain") - assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2") + assert.True(t, nb.HasClient(accountID), "client should remain after removing one service") + assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2") - // Remove another domain - client should still remain. + // Remove another service - client should still remain. err = nb.RemovePeer(context.Background(), accountID, "domain2.test") require.NoError(t, err) - assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain") - assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1") + assert.True(t, nb.HasClient(accountID), "client should remain after removing second service") + assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1") } -func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) { +func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add single domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add single service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) - // Remove the only domain - client should be removed. - // Note: Stop() may fail since the client never actually connected, - // but the entry should still be removed from the map. + // Remove the only service - client should be removed. _ = nb.RemovePeer(context.Background(), accountID, "domain1.test") - // After removing all domains, client should be gone. 
- assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain") - assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0") + // After removing all services, client should be gone. + assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service") + assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0") } func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { @@ -171,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) { assert.NoError(t, err, "removing from non-existent account should not error") } -func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) { +func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) { nb := mockNetBird() accountID := types.AccountID("account-1") - // Add one domain. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1") + // Add one service. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - // Remove non-existent domain - should not affect existing domain. - err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test")) + // Remove non-existent service - should not affect existing service. + err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test") require.NoError(t, err) - // Original domain should still be registered. + // Original service should still be registered. 
assert.True(t, nb.HasClient(accountID)) - assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain") + assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain") } func TestWithAccountID_AndAccountIDFromContext(t *testing.T) { @@ -216,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) { account2 := types.AccountID("account-2") account3 := types.AccountID("account-3") - // Add domains for multiple accounts. - err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1") + // Add services for multiple accounts. + err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3") + err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3")) require.NoError(t, err) assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients") // Stop all clients. - // Note: StopAll may return errors since clients never actually connected, - // but the clients should still be removed from the map. _ = nb.StopAll(context.Background()) assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll") @@ -243,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) { assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients") // Add clients for different accounts. 
- err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1") + err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1")) require.NoError(t, err) assert.Equal(t, 1, nb.ClientCount()) - err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2") + err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2")) require.NoError(t, err) assert.Equal(t, 2, nb.ClientCount()) - // Adding domain to existing account should not increase count. - err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b") + // Adding service to existing account should not increase count. + err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b")) require.NoError(t, err) - assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count") + assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count") } func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) { @@ -293,8 +285,8 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - // Add first domain — creates a new client entry. - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + // Add first service — creates a new client entry. + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) // Manually mark client as started to simulate background startup completing. 
@@ -302,15 +294,14 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { nb.clients[accountID].started = true nb.clientsMux.Unlock() - // Add second domain — should notify immediately since client is already started. - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + // Add second service — should notify immediately since client is already started. + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, string(accountID), calls[0].accountID) - assert.Equal(t, "svc-2", calls[0].serviceID) - assert.Equal(t, "domain2.test", calls[0].domain) + assert.Equal(t, accountID, calls[0].accountID) + assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID) assert.True(t, calls[0].connected) } @@ -323,18 +314,18 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") - err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") + err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1")) require.NoError(t, err) - err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2") + err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2")) require.NoError(t, err) - // Remove one domain — client stays, but disconnection notification fires. + // Remove one service — client stays, but disconnection notification fires. 
err = nb.RemovePeer(context.Background(), accountID, "domain1.test") require.NoError(t, err) assert.True(t, nb.HasClient(accountID)) calls := notifier.calls() require.Len(t, calls, 1) - assert.Equal(t, "domain1.test", calls[0].domain) + assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID) assert.False(t, calls[0].connected) } diff --git a/proxy/internal/tcp/bench_test.go b/proxy/internal/tcp/bench_test.go new file mode 100644 index 000000000..049f8395d --- /dev/null +++ b/proxy/internal/tcp/bench_test.go @@ -0,0 +1,133 @@ +package tcp + +import ( + "bytes" + "crypto/tls" + "io" + "net" + "testing" +) + +// BenchmarkPeekClientHello_TLS measures the overhead of peeking at a real +// TLS ClientHello and extracting the SNI. This is the per-connection cost +// added to every TLS connection on the main listener. +func BenchmarkPeekClientHello_TLS(b *testing.B) { + // Pre-generate a ClientHello by capturing what crypto/tls sends. + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + var hello []byte + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + hello = make([]byte, n) + copy(hello, buf[:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(hello) + conn := &readerConn{Reader: r} + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + // Simulate draining the peeked bytes (what the HTTP server would do). + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekClientHello_NonTLS measures peek overhead for non-TLS +// connections that hit the fast non-handshake exit path. 
+func BenchmarkPeekClientHello_NonTLS(b *testing.B) { + httpReq := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(httpReq) + conn := &readerConn{Reader: r} + _, wrapped, err := PeekClientHello(conn) + if err != nil { + b.Fatal(err) + } + _, _ = io.Copy(io.Discard, wrapped) + } +} + +// BenchmarkPeekedConn_Read measures the read overhead of the peekedConn +// wrapper compared to a plain connection read. The peeked bytes use +// io.MultiReader which adds one indirection per Read call. +func BenchmarkPeekedConn_Read(b *testing.B) { + data := make([]byte, 4096) + peeked := make([]byte, 512) + buf := make([]byte, 1024) + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + r := bytes.NewReader(data) + conn := &readerConn{Reader: r} + pc := newPeekedConn(conn, peeked) + for { + _, err := pc.Read(buf) + if err != nil { + break + } + } + } +} + +// BenchmarkExtractSNI measures just the in-memory SNI parsing cost, +// excluding I/O. +func BenchmarkExtractSNI(b *testing.B) { + clientConn, serverConn := net.Pipe() + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + buf := make([]byte, 16384) + n, _ := serverConn.Read(buf) + payload := make([]byte, n-tlsRecordHeaderLen) + copy(payload, buf[tlsRecordHeaderLen:n]) + clientConn.Close() + serverConn.Close() + + b.ResetTimer() + b.ReportAllocs() + + for b.Loop() { + sni := extractSNI(payload) + if sni != "app.example.com" { + b.Fatalf("unexpected SNI: %q", sni) + } + } +} + +// readerConn wraps an io.Reader as a net.Conn for benchmarking. +// Only Read is functional; all other methods are no-ops. 
+type readerConn struct { + io.Reader + net.Conn +} + +func (c *readerConn) Read(b []byte) (int, error) { + return c.Reader.Read(b) +} diff --git a/proxy/internal/tcp/chanlistener.go b/proxy/internal/tcp/chanlistener.go new file mode 100644 index 000000000..ee64bc0a2 --- /dev/null +++ b/proxy/internal/tcp/chanlistener.go @@ -0,0 +1,76 @@ +package tcp + +import ( + "net" + "sync" +) + +// chanListener implements net.Listener by reading connections from a channel. +// It allows the SNI router to feed HTTP connections to http.Server.ServeTLS. +type chanListener struct { + ch chan net.Conn + addr net.Addr + once sync.Once + closed chan struct{} +} + +func newChanListener(ch chan net.Conn, addr net.Addr) *chanListener { + return &chanListener{ + ch: ch, + addr: addr, + closed: make(chan struct{}), + } +} + +// Accept waits for and returns the next connection from the channel. +func (l *chanListener) Accept() (net.Conn, error) { + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + return conn, nil + case <-l.closed: + // Drain buffered connections before returning. + for { + select { + case conn, ok := <-l.ch: + if !ok { + return nil, net.ErrClosed + } + _ = conn.Close() + default: + return nil, net.ErrClosed + } + } + } + } +} + +// Close signals the listener to stop accepting connections and drains +// any buffered connections that have not yet been accepted. +func (l *chanListener) Close() error { + l.once.Do(func() { + close(l.closed) + for { + select { + case conn, ok := <-l.ch: + if !ok { + return + } + _ = conn.Close() + default: + return + } + } + }) + return nil +} + +// Addr returns the listener's network address. 
+func (l *chanListener) Addr() net.Addr { + return l.addr +} + +var _ net.Listener = (*chanListener)(nil) diff --git a/proxy/internal/tcp/peekedconn.go b/proxy/internal/tcp/peekedconn.go new file mode 100644 index 000000000..26f3e5c7c --- /dev/null +++ b/proxy/internal/tcp/peekedconn.go @@ -0,0 +1,39 @@ +package tcp + +import ( + "bytes" + "io" + "net" +) + +// peekedConn wraps a net.Conn and prepends previously peeked bytes +// so that readers see the full original stream transparently. +type peekedConn struct { + net.Conn + reader io.Reader +} + +func newPeekedConn(conn net.Conn, peeked []byte) *peekedConn { + return &peekedConn{ + Conn: conn, + reader: io.MultiReader(bytes.NewReader(peeked), conn), + } +} + +// Read replays the peeked bytes first, then reads from the underlying conn. +func (c *peekedConn) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// CloseWrite delegates to the underlying connection if it supports +// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn +// as an interface hides the concrete type's CloseWrite method, making +// half-close a silent no-op for all SNI-routed connections. +func (c *peekedConn) CloseWrite() error { + if hc, ok := c.Conn.(halfCloser); ok { + return hc.CloseWrite() + } + return nil +} + +var _ halfCloser = (*peekedConn)(nil) diff --git a/proxy/internal/tcp/proxyprotocol.go b/proxy/internal/tcp/proxyprotocol.go new file mode 100644 index 000000000..699b75a5d --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol.go @@ -0,0 +1,29 @@ +package tcp + +import ( + "fmt" + "net" + + "github.com/pires/go-proxyproto" +) + +// writeProxyProtoV2 sends a PROXY protocol v2 header to the backend connection, +// conveying the real client address. 
+func writeProxyProtoV2(client, backend net.Conn) error { + tp := proxyproto.TCPv4 + if addr, ok := client.RemoteAddr().(*net.TCPAddr); ok && addr.IP.To4() == nil { + tp = proxyproto.TCPv6 + } + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: tp, + SourceAddr: client.RemoteAddr(), + DestinationAddr: client.LocalAddr(), + } + if _, err := header.WriteTo(backend); err != nil { + return fmt.Errorf("write PROXY protocol v2 header: %w", err) + } + return nil +} diff --git a/proxy/internal/tcp/proxyprotocol_test.go b/proxy/internal/tcp/proxyprotocol_test.go new file mode 100644 index 000000000..f8c48b2ab --- /dev/null +++ b/proxy/internal/tcp/proxyprotocol_test.go @@ -0,0 +1,128 @@ +package tcp + +import ( + "bufio" + "net" + "testing" + + "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteProxyProtoV2_IPv4(t *testing.T) { + // Set up a real TCP listener and dial to get connections with real addresses. + ln, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp4", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + // Use a pipe as the backend: write the header to one end, read from the other. + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + // serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination. + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + // Read the PROXY protocol header from the backend read side. 
+ header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4") + + // serverConn.RemoteAddr() is the client's address (source in the header). + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + // serverConn.LocalAddr() is the server's address (destination in the header). + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} + +func TestWriteProxyProtoV2_IPv6(t *testing.T) { + // Set up a real TCP6 listener on loopback. 
+ ln, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skip("IPv6 not available:", err) + } + defer ln.Close() + + var serverConn net.Conn + accepted := make(chan struct{}) + go func() { + var err error + serverConn, err = ln.Accept() + if err != nil { + t.Error("accept failed:", err) + } + close(accepted) + }() + + clientConn, err := net.Dial("tcp6", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-accepted + defer serverConn.Close() + + backendRead, backendWrite := net.Pipe() + defer backendRead.Close() + defer backendWrite.Close() + + writeDone := make(chan error, 1) + go func() { + writeDone <- writeProxyProtoV2(serverConn, backendWrite) + }() + + header, err := proxyproto.Read(bufio.NewReader(backendRead)) + require.NoError(t, err) + require.NotNil(t, header, "should have received a proxy protocol header") + + writeErr := <-writeDone + require.NoError(t, writeErr) + + assert.Equal(t, byte(2), header.Version, "version should be 2") + assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY") + assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6") + + expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr) + actualSrc := header.SourceAddr.(*net.TCPAddr) + assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr") + assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr") + + expectedDst := serverConn.LocalAddr().(*net.TCPAddr) + actualDst := header.DestinationAddr.(*net.TCPAddr) + assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr") + assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr") +} diff --git a/proxy/internal/tcp/relay.go b/proxy/internal/tcp/relay.go new file mode 100644 index 000000000..39949818d --- /dev/null +++ b/proxy/internal/tcp/relay.go @@ -0,0 +1,156 @@ 
+package tcp + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +// errIdleTimeout is returned when a relay connection is closed due to inactivity. +var errIdleTimeout = errors.New("idle timeout") + +// DefaultIdleTimeout is the default idle timeout for TCP relay connections. +// A zero value disables idle timeout checking. +const DefaultIdleTimeout = 5 * time.Minute + +// halfCloser is implemented by connections that support half-close +// (e.g. *net.TCPConn). When one copy direction finishes, we signal +// EOF to the remote by closing the write side while keeping the read +// side open so the other direction can drain. +type halfCloser interface { + CloseWrite() error +} + +// copyBufPool avoids allocating a new 32KB buffer per io.Copy call. +var copyBufPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + +// Relay copies data bidirectionally between src and dst until both +// sides are done or the context is canceled. When idleTimeout is +// non-zero, each direction's read is deadline-guarded; if no data +// flows within the timeout the connection is torn down. When one +// direction finishes, it half-closes the write side of the +// destination (if supported) to signal EOF, allowing the other +// direction to drain gracefully before the full connection teardown. 
+func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + _ = src.Close() + _ = dst.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var errSrcToDst, errDstToSrc error + + go func() { + defer wg.Done() + srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout) + halfClose(dst) + cancel() + }() + + go func() { + defer wg.Done() + dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout) + halfClose(src) + cancel() + }() + + wg.Wait() + + if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) { + logger.Debug("relay closed due to idle timeout") + } + if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) { + logger.Debugf("relay copy error (src→dst): %v", errSrcToDst) + } + if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) { + logger.Debugf("relay copy error (dst→src): %v", errDstToSrc) + } + + return srcToDst, dstToSrc +} + +// copyWithIdleTimeout copies from src to dst using a pooled buffer. +// When idleTimeout > 0 it sets a read deadline on src before each +// read and treats a timeout as an idle-triggered close. 
+func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { + bufp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bufp) + + if idleTimeout <= 0 { + return io.CopyBuffer(dst, src, *bufp) + } + + conn, ok := src.(net.Conn) + if !ok { + return io.CopyBuffer(dst, src, *bufp) + } + + buf := *bufp + var total int64 + for { + if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { + return total, err + } + nr, readErr := src.Read(buf) + if nr > 0 { + n, err := checkedWrite(dst, buf[:nr]) + total += n + if err != nil { + return total, err + } + } + if readErr != nil { + if netutil.IsTimeout(readErr) { + return total, errIdleTimeout + } + return total, readErr + } + } +} + +// checkedWrite writes buf to dst and returns the number of bytes written. +// It guards against short writes and negative counts per io.Copy convention. +func checkedWrite(dst io.Writer, buf []byte) (int64, error) { + nw, err := dst.Write(buf) + if nw < 0 || nw > len(buf) { + nw = 0 + } + if err != nil { + return int64(nw), err + } + if nw != len(buf) { + return int64(nw), io.ErrShortWrite + } + return int64(nw), nil +} + +func isExpectedCopyError(err error) bool { + return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err) +} + +// halfClose attempts to half-close the write side of the connection. +// If the connection does not support half-close, this is a no-op. +func halfClose(conn net.Conn) { + if hc, ok := conn.(halfCloser); ok { + // Best-effort; the full close will follow shortly. 
+ _ = hc.CloseWrite() + } +} diff --git a/proxy/internal/tcp/relay_test.go b/proxy/internal/tcp/relay_test.go new file mode 100644 index 000000000..e42d65b9d --- /dev/null +++ b/proxy/internal/tcp/relay_test.go @@ -0,0 +1,210 @@ +package tcp + +import ( + "context" + "fmt" + "io" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/netutil" +) + +func TestRelay_BidirectionalCopy(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + srcData := []byte("hello from src") + dstData := []byte("hello from dst") + + // dst side: write response first, then read + close. + go func() { + _, _ = dstClient.Write(dstData) + buf := make([]byte, 256) + _, _ = dstClient.Read(buf) + dstClient.Close() + }() + + // src side: read the response, then send data + close. + go func() { + buf := make([]byte, 256) + _, _ = srcClient.Read(buf) + _, _ = srcClient.Write(srcData) + srcClient.Close() + }() + + s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0) + + assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst") + assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src") +} + +func TestRelay_ContextCancellation(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer srcClient.Close() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + // Cancel should cause Relay to return. 
+ cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after context cancellation") + } +} + +func TestRelay_OneSideClosed(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + defer dstClient.Close() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Close src immediately. Relay should complete without hanging. + srcClient.Close() + + done := make(chan struct{}) + go func() { + Relay(ctx, logger, srcServer, dstServer, 0) + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Relay did not return after one side closed") + } +} + +func TestRelay_LargeTransfer(t *testing.T) { + srcClient, srcServer := net.Pipe() + dstClient, dstServer := net.Pipe() + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // 1MB of data. + data := make([]byte, 1<<20) + for i := range data { + data[i] = byte(i % 256) + } + + go func() { + _, _ = srcClient.Write(data) + srcClient.Close() + }() + + errCh := make(chan error, 1) + go func() { + received, err := io.ReadAll(dstClient) + if err != nil { + errCh <- err + return + } + if len(received) != len(data) { + errCh <- fmt.Errorf("expected %d bytes, got %d", len(data), len(received)) + return + } + errCh <- nil + dstClient.Close() + }() + + s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0) + assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes") + require.NoError(t, <-errCh) +} + +func TestRelay_IdleTimeout(t *testing.T) { + // Use real TCP connections so SetReadDeadline works (net.Pipe + // does not support deadlines). 
+ srcLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer srcLn.Close() + + dstLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer dstLn.Close() + + srcClient, err := net.Dial("tcp", srcLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer srcClient.Close() + + srcServer, err := srcLn.Accept() + if err != nil { + t.Fatal(err) + } + + dstClient, err := net.Dial("tcp", dstLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer dstClient.Close() + + dstServer, err := dstLn.Accept() + if err != nil { + t.Fatal(err) + } + + logger := log.NewEntry(log.StandardLogger()) + ctx := context.Background() + + // Send initial data to prove the relay works. + go func() { + _, _ = srcClient.Write([]byte("ping")) + }() + + done := make(chan struct{}) + var s2d, d2s int64 + go func() { + s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) + close(done) + }() + + // Read the forwarded data on the dst side. + buf := make([]byte, 64) + n, err := dstClient.Read(buf) + assert.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Now stop sending. The relay should close after the idle timeout. 
+ select { + case <-done: + assert.Greater(t, s2d, int64(0), "should have transferred initial data") + _ = d2s + case <-time.After(5 * time.Second): + t.Fatal("Relay did not exit after idle timeout") + } +} + +func TestIsExpectedError(t *testing.T) { + assert.True(t, netutil.IsExpectedError(net.ErrClosed)) + assert.True(t, netutil.IsExpectedError(context.Canceled)) + assert.True(t, netutil.IsExpectedError(io.EOF)) + assert.False(t, netutil.IsExpectedError(io.ErrUnexpectedEOF)) +} diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go new file mode 100644 index 000000000..84fde0731 --- /dev/null +++ b/proxy/internal/tcp/router.go @@ -0,0 +1,570 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +// defaultDialTimeout is the fallback dial timeout when no per-route +// timeout is configured. +const defaultDialTimeout = 30 * time.Second + +// SNIHost is a typed key for SNI hostname lookups. +type SNIHost string + +// RouteType specifies how a connection should be handled. +type RouteType int + +const ( + // RouteHTTP routes the connection through the HTTP reverse proxy. + RouteHTTP RouteType = iota + // RouteTCP relays the connection directly to the backend (TLS passthrough). + RouteTCP +) + +const ( + // sniPeekTimeout is the deadline for reading the TLS ClientHello. + sniPeekTimeout = 5 * time.Second + // DefaultDrainTimeout is the default grace period for in-flight relay + // connections to finish during shutdown. + DefaultDrainTimeout = 30 * time.Second + // DefaultMaxRelayConns is the default cap on concurrent TCP relay connections per router. + DefaultMaxRelayConns = 4096 + // httpChannelBuffer is the capacity of the channel feeding HTTP connections. 
+ httpChannelBuffer = 4096 +) + +// DialResolver returns a DialContextFunc for the given account. +type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error) + +// Route describes where a connection for a given SNI should be sent. +type Route struct { + Type RouteType + AccountID types.AccountID + ServiceID types.ServiceID + // Domain is the service's configured domain, used for access log entries. + Domain string + // Protocol is the frontend protocol (tcp, tls), used for access log entries. + Protocol accesslog.Protocol + // Target is the backend address for TCP relay (e.g. "10.0.0.5:5432"). + Target string + // ProxyProtocol enables sending a PROXY protocol v2 header to the backend. + ProxyProtocol bool + // DialTimeout overrides the default dial timeout for this route. + // Zero uses defaultDialTimeout. + DialTimeout time.Duration +} + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// RelayObserver receives callbacks for TCP relay lifecycle events. +// All methods must be safe for concurrent use. +type RelayObserver interface { + TCPRelayStarted(accountID types.AccountID) + TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) + TCPRelayDialError(accountID types.AccountID) + TCPRelayRejected(accountID types.AccountID) +} + +// Router accepts raw TCP connections on a shared listener, peeks at +// the TLS ClientHello to extract the SNI, and routes the connection +// to either the HTTP reverse proxy or a direct TCP relay. +type Router struct { + logger *log.Logger + // httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter. 
+ httpCh chan net.Conn + httpListener *chanListener + mu sync.RWMutex + routes map[SNIHost][]Route + fallback *Route + draining bool + dialResolve DialResolver + activeConns sync.WaitGroup + activeRelays sync.WaitGroup + relaySem chan struct{} + drainDone chan struct{} + observer RelayObserver + accessLog l4Logger + // svcCtxs tracks a context per service ID. All relay goroutines for a + // service derive from its context; canceling it kills them immediately. + svcCtxs map[types.ServiceID]context.Context + svcCancels map[types.ServiceID]context.CancelFunc +} + +// NewRouter creates a new SNI-based connection router. +func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router { + httpCh := make(chan net.Conn, httpChannelBuffer) + return &Router{ + logger: logger, + httpCh: httpCh, + httpListener: newChanListener(httpCh, addr), + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// NewPortRouter creates a Router for a dedicated port without an HTTP +// channel. Connections that don't match any SNI route fall through to +// the fallback relay (if set) or are closed. +func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router { + return &Router{ + logger: logger, + routes: make(map[SNIHost][]Route), + dialResolve: dialResolve, + relaySem: make(chan struct{}, DefaultMaxRelayConns), + svcCtxs: make(map[types.ServiceID]context.Context), + svcCancels: make(map[types.ServiceID]context.CancelFunc), + } +} + +// HTTPListener returns a net.Listener that yields connections routed +// to the HTTP handler. Use this with http.Server.ServeTLS. +func (r *Router) HTTPListener() net.Listener { + return r.httpListener +} + +// AddRoute registers an SNI route. 
Multiple routes for the same host are +// stored and resolved by priority at lookup time (HTTP > TCP). +// Empty host is ignored to prevent conflicts with ECH/ESNI fallback. +func (r *Router) AddRoute(host SNIHost, route Route) { + if host == "" { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + routes := r.routes[host] + for i, existing := range routes { + if existing.ServiceID == route.ServiceID { + r.cancelServiceLocked(route.ServiceID) + routes[i] = route + return + } + } + r.routes[host] = append(routes, route) +} + +// RemoveRoute removes the route for the given host and service ID. +// Active relay connections for the service are closed immediately. +// If other routes remain for the host, they are preserved. +func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + + r.routes[host] = slices.DeleteFunc(r.routes[host], func(route Route) bool { + return route.ServiceID == svcID + }) + if len(r.routes[host]) == 0 { + delete(r.routes, host) + } + r.cancelServiceLocked(svcID) +} + +// SetFallback registers a catch-all route for connections that don't +// match any SNI route. On a port router this handles plain TCP relay; +// on the main router it takes priority over the HTTP channel. +func (r *Router) SetFallback(route Route) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = &route +} + +// RemoveFallback clears the catch-all fallback route and closes any +// active relay connections for the given service. +func (r *Router) RemoveFallback(svcID types.ServiceID) { + r.mu.Lock() + defer r.mu.Unlock() + r.fallback = nil + r.cancelServiceLocked(svcID) +} + +// SetObserver sets the relay lifecycle observer. Must be called before Serve. +func (r *Router) SetObserver(obs RelayObserver) { + r.mu.Lock() + defer r.mu.Unlock() + r.observer = obs +} + +// SetAccessLogger sets the L4 access logger. Must be called before Serve. 
+func (r *Router) SetAccessLogger(l l4Logger) { + r.mu.Lock() + defer r.mu.Unlock() + r.accessLog = l +} + +// getObserver returns the current relay observer under the read lock. +func (r *Router) getObserver() RelayObserver { + r.mu.RLock() + defer r.mu.RUnlock() + return r.observer +} + +// IsEmpty returns true when the router has no SNI routes and no fallback. +func (r *Router) IsEmpty() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.fallback == nil +} + +// Serve accepts connections from ln and routes them based on SNI. +// It blocks until ctx is canceled or ln is closed, then drains +// active relay connections up to DefaultDrainTimeout. +func (r *Router) Serve(ctx context.Context, ln net.Listener) error { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + _ = ln.Close() + if r.httpListener != nil { + r.httpListener.Close() + } + case <-done: + } + }() + + for { + conn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + if ok := r.Drain(DefaultDrainTimeout); !ok { + r.logger.Warn("timed out waiting for connections to drain") + } + return nil + } + r.logger.Debugf("SNI router accept: %v", err) + continue + } + r.activeConns.Add(1) + go func() { + defer r.activeConns.Done() + r.handleConn(ctx, conn) + }() + } +} + +// handleConn peeks at the TLS ClientHello and routes the connection. +func (r *Router) handleConn(ctx context.Context, conn net.Conn) { + // Fast path: when no SNI routes and no HTTP channel exist (pure TCP + // fallback port), skip the TLS peek entirely to avoid read errors on + // non-TLS connections and reduce latency. 
+ if r.isFallbackOnly() { + r.handleUnmatched(ctx, conn) + return + } + + if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil { + r.logger.Debugf("set SNI peek deadline: %v", err) + _ = conn.Close() + return + } + + sni, wrapped, err := PeekClientHello(conn) + if err != nil { + r.logger.Debugf("SNI peek: %v", err) + if wrapped != nil { + r.handleUnmatched(ctx, wrapped) + } else { + _ = conn.Close() + } + return + } + + if err := wrapped.SetReadDeadline(time.Time{}); err != nil { + r.logger.Debugf("clear SNI peek deadline: %v", err) + _ = wrapped.Close() + return + } + + host := SNIHost(sni) + route, ok := r.lookupRoute(host) + if !ok { + r.handleUnmatched(ctx, wrapped) + return + } + + if route.Type == RouteHTTP { + r.sendToHTTP(wrapped) + return + } + + if err := r.relayTCP(ctx, wrapped, host, route); err != nil { + r.logger.WithFields(log.Fields{ + "sni": host, + "service_id": route.ServiceID, + "target": route.Target, + }).Warnf("TCP relay: %v", err) + _ = wrapped.Close() + } +} + +// isFallbackOnly returns true when the router has no SNI routes and no HTTP +// channel, meaning all connections should go directly to the fallback relay. +func (r *Router) isFallbackOnly() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.routes) == 0 && r.httpCh == nil +} + +// handleUnmatched routes a connection that didn't match any SNI route. +// This includes ECH/ESNI connections where the cleartext SNI is empty. +// It tries the fallback relay first, then the HTTP channel, and closes +// the connection if neither is available. 
+func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) { + r.mu.RLock() + fb := r.fallback + r.mu.RUnlock() + + if fb != nil { + if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil { + r.logger.WithFields(log.Fields{ + "service_id": fb.ServiceID, + "target": fb.Target, + }).Warnf("TCP relay (fallback): %v", err) + _ = conn.Close() + } + return + } + r.sendToHTTP(conn) +} + +// lookupRoute returns the highest-priority route for the given SNI host. +// HTTP routes take precedence over TCP routes. +func (r *Router) lookupRoute(host SNIHost) (Route, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + routes, ok := r.routes[host] + if !ok || len(routes) == 0 { + return Route{}, false + } + best := routes[0] + for _, route := range routes[1:] { + if route.Type < best.Type { + best = route + } + } + return best, true +} + +// sendToHTTP feeds the connection to the HTTP handler via the channel. +// If no HTTP channel is configured (port router), the router is +// draining, or the channel is full, the connection is closed. +func (r *Router) sendToHTTP(conn net.Conn) { + if r.httpCh == nil { + _ = conn.Close() + return + } + + r.mu.RLock() + draining := r.draining + r.mu.RUnlock() + + if draining { + _ = conn.Close() + return + } + + select { + case r.httpCh <- conn: + default: + r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr()) + _ = conn.Close() + } +} + +// Drain prevents new relay connections from starting and waits for all +// in-flight connection handlers and active relays to finish, up to the +// given timeout. Returns true if all completed, false on timeout. 
+func (r *Router) Drain(timeout time.Duration) bool { + r.mu.Lock() + r.draining = true + if r.drainDone == nil { + done := make(chan struct{}) + go func() { + r.activeConns.Wait() + r.activeRelays.Wait() + close(done) + }() + r.drainDone = done + } + done := r.drainDone + r.mu.Unlock() + + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + +// cancelServiceLocked cancels and removes the context for the given service, +// closing all its active relay connections. Must be called with mu held. +func (r *Router) cancelServiceLocked(svcID types.ServiceID) { + if cancel, ok := r.svcCancels[svcID]; ok { + cancel() + delete(r.svcCtxs, svcID) + delete(r.svcCancels, svcID) + } +} + +// relayTCP sets up and runs a bidirectional TCP relay. +// The caller owns conn and must close it if this method returns an error. +// On success (nil error), both conn and backend are closed by the relay. +func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error { + svcCtx, err := r.acquireRelay(ctx, route) + if err != nil { + return err + } + defer func() { + <-r.relaySem + r.activeRelays.Done() + }() + + backend, err := r.dialBackend(svcCtx, route) + if err != nil { + obs := r.getObserver() + if obs != nil { + obs.TCPRelayDialError(route.AccountID) + } + return err + } + + if route.ProxyProtocol { + if err := writeProxyProtoV2(conn, backend); err != nil { + _ = backend.Close() + return fmt.Errorf("write PROXY protocol header: %w", err) + } + } + + obs := r.getObserver() + if obs != nil { + obs.TCPRelayStarted(route.AccountID) + } + + entry := r.logger.WithFields(log.Fields{ + "sni": sni, + "service_id": route.ServiceID, + "target": route.Target, + }) + entry.Debug("TCP relay started") + + start := time.Now() + s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout) + elapsed := time.Since(start) + + if obs != nil { + obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s) + } + entry.Debugf("TCP relay 
ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s) + + r.logL4Entry(route, conn, elapsed, s2d, d2s) + return nil +} + +// acquireRelay checks draining state, increments activeRelays, and acquires +// a semaphore slot. Returns the per-service context on success. +// The caller must release the semaphore and call activeRelays.Done() when done. +func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) { + r.mu.Lock() + if r.draining { + r.mu.Unlock() + return nil, errors.New("router is draining") + } + r.activeRelays.Add(1) + svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID) + r.mu.Unlock() + + select { + case r.relaySem <- struct{}{}: + return svcCtx, nil + default: + r.activeRelays.Done() + obs := r.getObserver() + if obs != nil { + obs.TCPRelayRejected(route.AccountID) + } + return nil, errors.New("TCP relay connection limit reached") + } +} + +// dialBackend resolves the dialer for the route's account and dials the backend. +func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) { + dialFn, err := r.dialResolve(route.AccountID) + if err != nil { + return nil, fmt.Errorf("resolve dialer: %w", err) + } + + dialTimeout := route.DialTimeout + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + dialCtx, dialCancel := context.WithTimeout(svcCtx, dialTimeout) + backend, err := dialFn(dialCtx, "tcp", route.Target) + dialCancel() + if err != nil { + return nil, fmt.Errorf("dial backend %s: %w", route.Target, err) + } + return backend, nil +} + +// logL4Entry sends a TCP relay access log entry if an access logger is configured. 
+func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) { + r.mu.RLock() + al := r.accessLog + r.mu.RUnlock() + + if al == nil { + return + } + + var sourceIP netip.Addr + if remote := conn.RemoteAddr(); remote != nil { + if ap, err := netip.ParseAddrPort(remote.String()); err == nil { + sourceIP = ap.Addr().Unmap() + } + } + + al.LogL4(accesslog.L4Entry{ + AccountID: route.AccountID, + ServiceID: route.ServiceID, + Protocol: route.Protocol, + Host: route.Domain, + SourceIP: sourceIP, + DurationMs: duration.Milliseconds(), + BytesUpload: bytesUp, + BytesDownload: bytesDown, + }) +} + +// getOrCreateServiceCtxLocked returns the context for a service, creating one +// if it doesn't exist yet. The context is a child of the server context. +// Must be called with mu held. +func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context { + if ctx, ok := r.svcCtxs[svcID]; ok { + return ctx + } + ctx, cancel := context.WithCancel(parent) + r.svcCtxs[svcID] = ctx + r.svcCancels[svcID] = cancel + return ctx +} diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go new file mode 100644 index 000000000..0e2cfe3e1 --- /dev/null +++ b/proxy/internal/tcp/router_test.go @@ -0,0 +1,1670 @@ +package tcp + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRouter_HTTPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + router.AddRoute("example.com", Route{Type: RouteHTTP}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, 
cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Dial in a goroutine. The TLS handshake will block since nothing + // completes it on the HTTP side, but we only care about routing. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + // Send a TLS ClientHello manually. + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + // Verify the connection was routed to the HTTP channel. + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("no connection received on HTTP channel") + } +} + +func TestRouter_TCPRouting(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Set up a TLS backend that the relay will connect to. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + backendAddr := backendLn.Addr().String() + + // Accept one connection on the backend, echo data back. 
+ backendReady := make(chan struct{}) + go func() { + close(backendReady) + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + <-backendReady + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendAddr, + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Connect as a TLS client; the proxy should passthrough to the backend. + clientConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer clientConn.Close() + + testData := []byte("hello through TCP passthrough") + _, err = clientConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := clientConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through TCP passthrough") +} + +func TestRouter_UnknownSNIGoesToHTTP(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr) + // No routes registered. 
+ + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + tlsConn.Close() + }() + + select { + case conn := <-router.httpCh: + assert.NotNil(t, conn) + conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("unknown SNI should be routed to HTTP") + } +} + +// TestRouter_NonTLSConnectionDropped verifies that a non-TLS connection +// on the shared port is closed by the router (SNI peek fails to find a +// valid ClientHello, so there is no route match). +func TestRouter_NonTLSConnectionDropped(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + // Register a TLS passthrough route. Non-TLS should NOT match. + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: "127.0.0.1:9999", + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Send plain HTTP (non-TLS) data. 
+ conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: tcp.example.com\r\n\r\n")) + + // Non-TLS traffic on a port with RouteTCP goes to the HTTP channel + // because there's no valid SNI to match. Verify it reaches HTTP. + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "non-TLS connection should fall through to HTTP") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("non-TLS connection was not routed to HTTP") + } +} + +// TestRouter_TLSAndHTTPCoexist verifies that a shared port with both HTTP +// and TLS passthrough routes correctly demuxes based on the SNI hostname. +func TestRouter_TLSAndHTTPCoexist(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes data. + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + router := NewRouter(logger, dialResolve, addr) + // HTTP route. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // TLS passthrough route. 
+ router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // 1. TLS connection with SNI "tcp.example.com" → TLS passthrough. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + testData := []byte("passthrough data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "TLS passthrough should relay data") + tlsConn.Close() + + // 2. TLS connection with SNI "app.example.com" → HTTP handler. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "HTTP SNI should go to HTTP handler") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } +} + +func TestRouter_AddRemoveRoute(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.1:5432"}) + + route, ok := router.lookupRoute("a.example.com") + assert.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type) + + route, ok = 
router.lookupRoute("b.example.com") + assert.True(t, ok) + assert.Equal(t, RouteTCP, route.Type) + + router.RemoveRoute("a.example.com", "svc-a") + _, ok = router.lookupRoute("a.example.com") + assert.False(t, ok) +} + +func TestChanListener_AcceptAndClose(t *testing.T) { + ch := make(chan net.Conn, 1) + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + ln := newChanListener(ch, addr) + + assert.Equal(t, addr, ln.Addr()) + + // Send a connection. + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + ch <- serverConn + + conn, err := ln.Accept() + require.NoError(t, err) + assert.Equal(t, serverConn, conn) + + // Close should cause Accept to return error. + require.NoError(t, ln.Close()) + // Double close should be safe. + require.NoError(t, ln.Close()) + + _, err = ln.Accept() + assert.ErrorIs(t, err, net.ErrClosed) +} + +func TestRouter_HTTPPrecedenceGuard(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, nil, addr) + + host := SNIHost("app.example.com") + + t.Run("http takes precedence over tcp at lookup", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP route must take precedence over TCP") + assert.Equal(t, types.ServiceID("svc-http"), route.ServiceID) + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("tcp becomes active when http is removed", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + + router.RemoveRoute(host, "svc-http") + + route, ok := router.lookupRoute(host) + require.True(t, ok) + 
assert.Equal(t, RouteTCP, route.Type, "TCP should take over after HTTP removal") + assert.Equal(t, types.ServiceID("svc-tcp"), route.ServiceID) + + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("order of add does not matter", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, RouteHTTP, route.Type, "HTTP takes precedence regardless of add order") + + router.RemoveRoute(host, "svc-http") + router.RemoveRoute(host, "svc-tcp") + }) + + t.Run("same service id updates in place", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.1:443"}) + router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.2:443"}) + + route, ok := router.lookupRoute(host) + require.True(t, ok) + assert.Equal(t, "10.0.0.2:443", route.Target, "route should be updated in place") + + router.RemoveRoute(host, "svc-1") + _, ok = router.lookupRoute(host) + assert.False(t, ok) + }) + + t.Run("double remove is safe", func(t *testing.T) { + router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-1"}) + router.RemoveRoute(host, "svc-1") + router.RemoveRoute(host, "svc-1") + + _, ok := router.lookupRoute(host) + assert.False(t, ok, "route should be gone after removal") + }) + + t.Run("remove does not affect other hosts", func(t *testing.T) { + router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"}) + router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.2:22"}) + + router.RemoveRoute("a.example.com", "svc-a") + + _, ok := router.lookupRoute(SNIHost("a.example.com")) + assert.False(t, ok) + + route, ok := router.lookupRoute(SNIHost("b.example.com")) + require.True(t, ok) + assert.Equal(t, RouteTCP, route.Type, "removing one host must not affect another") + 
+ router.RemoveRoute("b.example.com", "svc-b") + }) +} + +func TestRouter_SetRemoveFallback(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new port router should be empty") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb", Target: "10.0.0.1:5432"}) + assert.False(t, router.IsEmpty(), "router with fallback should not be empty") + + router.AddRoute("a.example.com", Route{Type: RouteTCP, ServiceID: "svc-a", Target: "10.0.0.2:443"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb") + assert.False(t, router.IsEmpty(), "router with SNI route should not be empty") + + router.RemoveRoute("a.example.com", "svc-a") + assert.True(t, router.IsEmpty(), "router with no routes and no fallback should be empty") +} + +func TestPortRouter_FallbackRelaysData(t *testing.T) { + // Backend echo server. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Plain TCP (non-TLS) connection should be relayed via fallback. + // Use exactly 5 bytes. 
PeekClientHello reads 5 bytes as the TLS + // header, so a single 5-byte write lands as one chunk at the backend. + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("hello") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "should receive echoed data through fallback relay") +} + +func TestPortRouter_FallbackOnUnknownSNI(t *testing.T) { + // Backend TLS echo server. + backendCert := generateSelfSignedCert(t) + backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{backendCert}, + }) + require.NoError(t, err) + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + // Only a fallback, no SNI route for "unknown.example.com". + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "test-service", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with unknown SNI → fallback relay to TLS backend. 
+ tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("hello through fallback TLS") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "unknown SNI should relay through fallback") +} + +func TestPortRouter_SNIWinsOverFallback(t *testing.T) { + // Two backend echo servers: one for SNI match, one for fallback. + sniBacked := startEchoTLS(t) + fbBacked := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "sni-service", + Target: sniBacked.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBacked.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // TLS with matching SNI should go to SNI backend, not fallback. 
+ tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + defer tlsConn.Close() + + testData := []byte("SNI route data") + _, err = tlsConn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "SNI match should use SNI route, not fallback") +} + +func TestPortRouter_NoFallbackNoHTTP_Closes(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + _, _ = conn.Write([]byte("hello")) + + // Connection should be closed by the router (no fallback, no HTTP). + buf := make([]byte, 1) + _ = conn.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be closed when no fallback and no HTTP channel") +} + +func TestRouter_FallbackAndHTTPCoexist(t *testing.T) { + // Fallback backend echo server (plain TCP). 
+ fbBackend, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer fbBackend.Close() + + go func() { + conn, err := fbBackend.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + + // HTTP route for known SNI. + router.AddRoute("app.example.com", Route{Type: RouteHTTP}) + // Fallback for non-TLS / unknown SNI. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "test-account", + ServiceID: "fb-service", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with known HTTP SNI → should go to HTTP channel. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + c := tls.Client(conn, &tls.Config{ + ServerName: "app.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = c.Handshake() + c.Close() + }() + + select { + case httpConn := <-router.httpCh: + assert.NotNil(t, httpConn, "known HTTP SNI should go to HTTP channel") + httpConn.Close() + case <-time.After(5 * time.Second): + t.Fatal("HTTP-route connection was not delivered to HTTP handler") + } + + // 2. Plain TCP (non-TLS) → should go to fallback, not HTTP. + // Use exactly 5 bytes to match PeekClientHello header size. 
+ conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn.Close() + + testData := []byte("plain") + _, err = conn.Write(testData) + require.NoError(t, err) + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n], "non-TLS should be relayed via fallback, not HTTP") +} + +// startEchoTLS starts a TLS echo server and returns the listener. +func startEchoTLS(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + if _, err := conn.Write(buf[:n]); err != nil { + return + } + } + }() + + return ln +} + +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"tcp.example.com"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: key, + } +} + +func TestRouter_DrainWaitsForRelays(t *testing.T) { + logger := log.StandardLogger() + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Accept connections: echo first message, then hold open until told to close. 
+ closeBackend := make(chan struct{}) + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + <-closeBackend + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + serveDone := make(chan struct{}) + go func() { + _ = router.Serve(ctx, ln) + close(serveDone) + }() + + // Open a relay connection (non-TLS, hits fallback). + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + _, _ = conn.Write([]byte("hello")) + + // Wait for the echo to confirm the relay is fully established. + buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Drain with a short timeout should fail because the relay is still active. + assert.False(t, router.Drain(50*time.Millisecond), "drain should timeout with active relay") + + // Close backend connections so relays finish. + close(closeBackend) + _ = conn.Close() + + // Drain should now complete quickly. 
+ assert.True(t, router.Drain(2*time.Second), "drain should succeed after relays end") + + cancel() + <-serveDone +} + +func TestRouter_DrainEmptyReturnsImmediately(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + start := time.Now() + ok := router.Drain(5 * time.Second) + elapsed := time.Since(start) + + assert.True(t, ok) + assert.Less(t, elapsed, 100*time.Millisecond, "drain with no relays should return immediately") +} + +// TestRemoveRoute_KillsActiveRelays verifies that removing a route +// immediately kills active relay connections for that service. +func TestRemoveRoute_KillsActiveRelays(t *testing.T) { + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backendLn.Close() + + // Backend echoes first message, then holds connection open. + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + // Hold the connection open. + for { + if _, err := c.Read(buf); err != nil { + return + } + } + }(conn) + } + }() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-1", + Target: backendLn.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a relay connection. + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write([]byte("hello")) + require.NoError(t, err) + + // Wait for echo to confirm relay is established. 
+ buf := make([]byte, 16) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + _ = conn.SetReadDeadline(time.Time{}) + + // Remove the fallback: should kill the active relay. + router.RemoveFallback("svc-1") + + // The client connection should see an error (server closed). + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(buf) + assert.Error(t, err, "connection should be killed after service removal") +} + +// TestRemoveRoute_KillsSNIRelays verifies that removing an SNI route +// kills its active relays without affecting other services. +func TestRemoveRoute_KillsSNIRelays(t *testing.T) { + backend := startEchoTLS(t) + defer backend.Close() + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + router := NewRouter(logger, dialResolve, addr) + router.AddRoute("tls.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-tls", + Target: backend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Establish a TLS relay. + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + defer tlsConn.Close() + + _, err = tlsConn.Write([]byte("ping")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "ping", string(buf[:n])) + + // Remove the route: active relay should die. 
+ router.RemoveRoute("tls.example.com", "svc-tls") + + _ = tlsConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = tlsConn.Read(buf) + assert.Error(t, err, "TLS relay should be killed after route removal") +} + +// TestPortRouter_SNIAndTCPFallbackCoexist verifies that a single port can +// serve both SNI-routed TLS passthrough and plain TCP fallback simultaneously. +func TestPortRouter_SNIAndTCPFallbackCoexist(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + // SNI route for a specific domain. + router.AddRoute("tcp.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-sni", + Target: sniBackend.Addr().String(), + }) + // TCP fallback for everything else. + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // 1. TLS with matching SNI → goes to SNI backend. + tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "tcp.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = tlsConn.Write([]byte("sni-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-data", string(buf[:n]), "SNI match → SNI backend") + tlsConn.Close() + + // 2. Plain TCP (no TLS) → goes to fallback. 
+ tcpConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + + _, err = tcpConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = tcpConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n]), "plain TCP → fallback backend") + tcpConn.Close() + + // 3. TLS with unknown SNI → also goes to fallback. + unknownBackend := startEchoTLS(t) + router.SetFallback(Route{ + Type: RouteTCP, + AccountID: "acct-2", + ServiceID: "svc-fb", + Target: unknownBackend.Addr().String(), + }) + + unknownTLS, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "unknown.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + + _, err = unknownTLS.Write([]byte("unknown-sni")) + require.NoError(t, err) + n, err = unknownTLS.Read(buf) + require.NoError(t, err) + assert.Equal(t, "unknown-sni", string(buf[:n]), "unknown SNI → fallback backend") + unknownTLS.Close() +} + +// TestPortRouter_UpdateRouteSwapsSNI verifies that updating a route +// (remove + add with different target) correctly routes to the new backend. +func TestPortRouter_UpdateRouteSwapsSNI(t *testing.T) { + backend1 := startEchoTLS(t) + backend2 := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Initial route → backend1. 
+ router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend1.Addr().String(), + }) + + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("v1")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v1", string(buf[:n])) + conn1.Close() + + // Update: remove old route, add new → backend2. + router.RemoveRoute("db.example.com", "svc-db") + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: backend2.Addr().String(), + }) + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("v2")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "v2", string(buf[:n])) + conn2.Close() +} + +// TestPortRouter_RemoveSNIFallsThrough verifies that after removing an +// SNI route, connections for that domain fall through to the fallback. 
+func TestPortRouter_RemoveSNIFallsThrough(t *testing.T) { + sniBackend := startEchoTLS(t) + fbBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.AddRoute("db.example.com", Route{ + Type: RouteTCP, + ServiceID: "svc-db", + Target: sniBackend.Addr().String(), + }) + router.SetFallback(Route{ + Type: RouteTCP, + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Before removal: SNI matches → sniBackend. + conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn1.Write([]byte("before")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "before", string(buf[:n])) + conn1.Close() + + // Remove SNI route. Should fall through to fallback. + router.RemoveRoute("db.example.com", "svc-db") + + conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + ServerName: "db.example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + require.NoError(t, err) + _, err = conn2.Write([]byte("after")) + require.NoError(t, err) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n]), "after removal, should reach fallback") + conn2.Close() +} + +// TestPortRouter_RemoveFallbackCloses verifies that after removing the +// fallback, non-matching connections are closed. 
+func TestPortRouter_RemoveFallbackCloses(t *testing.T) { + fbBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return func(_ context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, address) + }, nil + } + + logger := log.StandardLogger() + router := NewPortRouter(logger, dialResolve) + router.SetFallback(Route{ + Type: RouteTCP, + ServiceID: "svc-fb", + Target: fbBackend.Addr().String(), + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // With fallback: plain TCP works. + conn1, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = conn1.Write([]byte("hello")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello", string(buf[:n])) + conn1.Close() + + // Remove fallback. + router.RemoveFallback("svc-fb") + + // Without fallback on a port router (no HTTP channel): connection should be closed. + conn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn2.Close() + _, _ = conn2.Write([]byte("bye")) + _ = conn2.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err = conn2.Read(buf) + assert.Error(t, err, "without fallback, connection should be closed") +} + +// TestPortRouter_HTTPToTLSTransition verifies that switching a service from +// HTTP-only to TLS-only via remove+add doesn't orphan the old HTTP route. 
+func TestPortRouter_HTTPToTLSTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: HTTP-only. SNI connections go to HTTP channel. + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener did not receive connection for HTTP-only route") + } + + // Phase 2: Simulate update to TLS-only (removeMapping + addMapping). 
+ router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS connection should succeed after HTTP→TLS transition") + defer tlsConn2.Close() + + _, err = tlsConn2.Write([]byte("hello-tls")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "hello-tls", string(buf[:n]), "data should relay to TLS backend") +} + +// TestPortRouter_TLSToHTTPTransition verifies that switching a service from +// TLS-only to HTTP-only via remove+add doesn't orphan the old TLS route. +func TestPortRouter_TLSToHTTPTransition(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + tlsBackend := startEchoTLS(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewRouter(logger, dialResolve, addr) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Phase 1: TLS-only. Route relays to backend. 
+ router.AddRoute("app.example.com", Route{ + Type: RouteTCP, + AccountID: "acct-1", + ServiceID: "svc-1", + Target: tlsBackend.Addr().String(), + }) + + tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err, "TLS relay should work before transition") + _, err = tlsConn.Write([]byte("tls-data")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "tls-data", string(buf[:n])) + tlsConn.Close() + + // Phase 2: Simulate update to HTTP-only (removeMapping + addMapping). + router.RemoveRoute("app.example.com", "svc-1") + router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"}) + + // TLS connection should now go to the HTTP listener, NOT to the old TLS backend. + httpConn := router.HTTPListener() + connDone := make(chan struct{}) + go func() { + defer close(connDone) + c, err := httpConn.Accept() + if err == nil { + c.Close() + } + }() + tlsConn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true}, + ) + if err == nil { + tlsConn2.Close() + } + select { + case <-connDone: + case <-time.After(2 * time.Second): + t.Fatal("HTTP listener should receive connection after TLS→HTTP transition") + } +} + +// TestPortRouter_MultiDomainSamePort verifies that two TLS services sharing +// the same port router are independently routable and removable. 
+func TestPortRouter_MultiDomainSamePort(t *testing.T) { + logger := log.StandardLogger() + backend1 := startEchoTLSMulti(t) + backend2 := startEchoTLSMulti(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + router.AddRoute("svc1.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: backend1.Addr().String()}) + router.AddRoute("svc2.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-2", Target: backend2.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // Both domains route independently. + for _, tc := range []struct { + sni string + data string + }{ + {"svc1.example.com", "hello-svc1"}, + {"svc2.example.com", "hello-svc2"}, + } { + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: tc.sni, InsecureSkipVerify: true}, + ) + require.NoError(t, err, "dial %s", tc.sni) + _, err = conn.Write([]byte(tc.data)) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, tc.data, string(buf[:n])) + conn.Close() + } + + // Remove svc1. Router should NOT be empty (svc2 still present). + router.RemoveRoute("svc1.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "router should not be empty with one route remaining") + + // svc2 still works. 
+ conn2, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "svc2.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = conn2.Write([]byte("still-alive")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := conn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "still-alive", string(buf[:n])) + conn2.Close() + + // Remove svc2. Router is now empty. + router.RemoveRoute("svc2.example.com", "svc-2") + assert.True(t, router.IsEmpty(), "router should be empty after removing all routes") +} + +// TestPortRouter_SNIAndFallbackLifecycle verifies the full lifecycle of SNI +// routes and TCP fallback coexisting on the same port router, including the +// ordering of add/remove operations. +func TestPortRouter_SNIAndFallbackLifecycle(t *testing.T) { + logger := log.StandardLogger() + sniBackend := startEchoTLS(t) + fallbackBackend := startEchoPlain(t) + + dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + } + + router := NewPortRouter(logger, dialResolve) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Step 1: Add fallback first (port mapping), then SNI route (TLS service). + router.SetFallback(Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "pm-1", Target: fallbackBackend.Addr().String()}) + router.AddRoute("tls.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: sniBackend.Addr().String()}) + assert.False(t, router.IsEmpty()) + + // SNI traffic goes to TLS backend. 
+ tlsConn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 2 * time.Second}, + "tcp", ln.Addr().String(), + &tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true}, + ) + require.NoError(t, err) + _, err = tlsConn.Write([]byte("sni-traffic")) + require.NoError(t, err) + buf := make([]byte, 1024) + n, err := tlsConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "sni-traffic", string(buf[:n])) + tlsConn.Close() + + // Plain TCP goes to fallback. + plainConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _, err = plainConn.Write([]byte("plain")) + require.NoError(t, err) + n, err = plainConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "plain", string(buf[:n])) + plainConn.Close() + + // Step 2: Remove SNI route. Fallback still works, router not empty. + router.RemoveRoute("tls.example.com", "svc-1") + assert.False(t, router.IsEmpty(), "fallback still present") + + plainConn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + // Must send >= 5 bytes so the SNI peek completes immediately + // without waiting for the 5-second peek timeout. + _, err = plainConn2.Write([]byte("after")) + require.NoError(t, err) + n, err = plainConn2.Read(buf) + require.NoError(t, err) + assert.Equal(t, "after", string(buf[:n])) + plainConn2.Close() + + // Step 3: Remove fallback. Router is now empty. + router.RemoveFallback("pm-1") + assert.True(t, router.IsEmpty()) +} + +// TestPortRouter_IsEmptyTransitions verifies IsEmpty reflects correct state +// through all add/remove operations. 
+func TestPortRouter_IsEmptyTransitions(t *testing.T) { + logger := log.StandardLogger() + router := NewPortRouter(logger, nil) + + assert.True(t, router.IsEmpty(), "new router") + + router.AddRoute("a.com", Route{Type: RouteTCP, ServiceID: "svc-a"}) + assert.False(t, router.IsEmpty(), "after adding route") + + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb1"}) + assert.False(t, router.IsEmpty(), "route + fallback") + + router.RemoveRoute("a.com", "svc-a") + assert.False(t, router.IsEmpty(), "fallback only") + + router.RemoveFallback("svc-fb1") + assert.True(t, router.IsEmpty(), "all removed") + + // Reverse order: fallback first, then route. + router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb2"}) + assert.False(t, router.IsEmpty()) + + router.AddRoute("b.com", Route{Type: RouteTCP, ServiceID: "svc-b"}) + assert.False(t, router.IsEmpty()) + + router.RemoveFallback("svc-fb2") + assert.False(t, router.IsEmpty(), "route still present") + + router.RemoveRoute("b.com", "svc-b") + assert.True(t, router.IsEmpty(), "fully empty again") +} + +// startEchoTLSMulti starts a TLS echo server that accepts multiple connections. +func startEchoTLSMulti(t *testing.T) net.Listener { + t.Helper() + + cert := generateSelfSignedCert(t) + ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} + +// startEchoPlain starts a plain TCP echo server that reads until newline +// or connection close, then echoes the received data. 
+func startEchoPlain(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { ln.Close() }) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + // Set a read deadline so we don't block forever waiting for more data. + _ = c.SetReadDeadline(time.Now().Add(2 * time.Second)) + buf := make([]byte, 1024) + n, _ := c.Read(buf) + _, _ = c.Write(buf[:n]) + }(conn) + } + }() + + return ln +} diff --git a/proxy/internal/tcp/snipeek.go b/proxy/internal/tcp/snipeek.go new file mode 100644 index 000000000..25ab8e5ef --- /dev/null +++ b/proxy/internal/tcp/snipeek.go @@ -0,0 +1,191 @@ +package tcp + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" +) + +const ( + // TLS record header is 5 bytes: ContentType(1) + Version(2) + Length(2). + tlsRecordHeaderLen = 5 + // TLS handshake type for ClientHello. + handshakeTypeClientHello = 1 + // TLS ContentType for handshake messages. + contentTypeHandshake = 22 + // SNI extension type (RFC 6066). + extensionServerName = 0 + // SNI host name type. + sniHostNameType = 0 + // maxClientHelloLen caps the ClientHello size we're willing to buffer. + maxClientHelloLen = 16384 + // maxSNILen is the maximum valid DNS hostname length per RFC 1035. + maxSNILen = 253 +) + +// PeekClientHello reads the TLS ClientHello from conn, extracts the SNI +// server name, and returns a wrapped connection that replays the peeked +// bytes transparently. If the data is not a valid TLS ClientHello or +// contains no SNI extension, sni is empty and err is nil. +// +// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the +// real server name is encrypted inside the encrypted_client_hello +// extension. 
This parser only reads the cleartext server_name extension +// (type 0x0000), so ECH connections return sni="" and are routed through +// the fallback path (or HTTP channel), which is the correct behavior +// for a transparent proxy that does not terminate TLS. +func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) { + // Read the 5-byte TLS record header into a small stack-friendly buffer. + var header [tlsRecordHeaderLen]byte + if _, err := io.ReadFull(conn, header[:]); err != nil { + return "", nil, fmt.Errorf("read TLS record header: %w", err) + } + + if header[0] != contentTypeHandshake { + return "", newPeekedConn(conn, header[:]), nil + } + + recordLen := int(binary.BigEndian.Uint16(header[3:5])) + if recordLen == 0 || recordLen > maxClientHelloLen { + return "", newPeekedConn(conn, header[:]), nil + } + + // Single allocation for header + payload. The peekedConn takes + // ownership of this buffer, so no further copies are needed. + buf := make([]byte, tlsRecordHeaderLen+recordLen) + copy(buf, header[:]) + + n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:]) + if err != nil { + return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err) + } + + sni = extractSNI(buf[tlsRecordHeaderLen:]) + return sni, newPeekedConn(conn, buf), nil +} + +// extractSNI parses a TLS handshake payload to find the SNI extension. +// Returns empty string if the payload is not a ClientHello or has no SNI. +func extractSNI(payload []byte) string { + if len(payload) < 4 { + return "" + } + + if payload[0] != handshakeTypeClientHello { + return "" + } + + // Handshake length (3 bytes, big-endian). 
+ handshakeLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3]) + if handshakeLen > len(payload)-4 { + return "" + } + + return parseSNIFromClientHello(payload[4 : 4+handshakeLen]) +} + +// parseSNIFromClientHello walks the ClientHello message fields to reach +// the extensions block and extract the server_name extension value. +func parseSNIFromClientHello(msg []byte) string { + // ClientHello layout: + // ProtocolVersion(2) + Random(32) = 34 bytes minimum before session_id + if len(msg) < 34 { + return "" + } + + pos := 34 + + // Session ID (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + sessionIDLen := int(msg[pos]) + pos++ + pos += sessionIDLen + + // Cipher suites (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + cipherSuitesLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + cipherSuitesLen + + // Compression methods (variable, 1 byte length prefix). + if pos >= len(msg) { + return "" + } + compMethodsLen := int(msg[pos]) + pos++ + pos += compMethodsLen + + // Extensions (variable, 2 byte length prefix). + if pos+2 > len(msg) { + return "" + } + extensionsLen := int(binary.BigEndian.Uint16(msg[pos : pos+2])) + pos += 2 + + extensionsEnd := pos + extensionsLen + if extensionsEnd > len(msg) { + return "" + } + + return findSNIExtension(msg[pos:extensionsEnd]) +} + +// findSNIExtension iterates over TLS extensions and returns the host +// name from the server_name extension, if present. 
+func findSNIExtension(extensions []byte) string { + pos := 0 + for pos+4 <= len(extensions) { + extType := binary.BigEndian.Uint16(extensions[pos : pos+2]) + extLen := int(binary.BigEndian.Uint16(extensions[pos+2 : pos+4])) + pos += 4 + + if pos+extLen > len(extensions) { + return "" + } + + if extType == extensionServerName { + return parseSNIExtensionData(extensions[pos : pos+extLen]) + } + pos += extLen + } + return "" +} + +// parseSNIExtensionData parses the ServerNameList structure inside an +// SNI extension to extract the host name. +func parseSNIExtensionData(data []byte) string { + if len(data) < 2 { + return "" + } + listLen := int(binary.BigEndian.Uint16(data[0:2])) + if listLen > len(data)-2 { + return "" + } + + list := data[2 : 2+listLen] + pos := 0 + for pos+3 <= len(list) { + nameType := list[pos] + nameLen := int(binary.BigEndian.Uint16(list[pos+1 : pos+3])) + pos += 3 + + if pos+nameLen > len(list) { + return "" + } + + if nameType == sniHostNameType { + name := list[pos : pos+nameLen] + if nameLen > maxSNILen || bytes.ContainsRune(name, 0) { + return "" + } + return string(name) + } + pos += nameLen + } + return "" +} diff --git a/proxy/internal/tcp/snipeek_test.go b/proxy/internal/tcp/snipeek_test.go new file mode 100644 index 000000000..9afe6261d --- /dev/null +++ b/proxy/internal/tcp/snipeek_test.go @@ -0,0 +1,251 @@ +package tcp + +import ( + "crypto/tls" + "io" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPeekClientHello_ValidSNI(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + const expectedSNI = "example.com" + trailingData := []byte("trailing data after handshake") + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: expectedSNI, + InsecureSkipVerify: true, //nolint:gosec + }) + // The Handshake will send the ClientHello. 
It will fail because + // our server side isn't doing a real TLS handshake, but that's + // fine: we only need the ClientHello to be sent. + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello") + assert.NotNil(t, wrapped, "wrapped connection should not be nil") + + // Verify the wrapped connection replays the peeked bytes. + // Read the first 5 bytes (TLS record header) to confirm replay. + buf := make([]byte, 5) + n, err := wrapped.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, byte(contentTypeHandshake), buf[0], "first byte should be TLS handshake content type") + + // Write trailing data from the client side and verify it arrives + // through the wrapped connection after the peeked bytes. + go func() { + _, _ = clientConn.Write(trailingData) + }() + + // Drain the rest of the peeked ClientHello first. + peekedRest := make([]byte, 16384) + _, _ = wrapped.Read(peekedRest) + + got := make([]byte, len(trailingData)) + n, err = io.ReadFull(wrapped, got) + require.NoError(t, err) + assert.Equal(t, trailingData, got[:n]) +} + +func TestPeekClientHello_MultipleSNIs(t *testing.T) { + tests := []struct { + name string + serverName string + expectedSNI string + }{ + {"simple domain", "example.com", "example.com"}, + {"subdomain", "sub.example.com", "sub.example.com"}, + {"deep subdomain", "a.b.c.example.com", "a.b.c.example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: tt.serverName, + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Equal(t, tt.expectedSNI, sni) + assert.NotNil(t, wrapped) + }) + 
} +} + +func TestPeekClientHello_NonTLSData(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // Send plain HTTP data (not TLS). + httpData := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + go func() { + _, _ = clientConn.Write(httpData) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni, "should return empty SNI for non-TLS data") + assert.NotNil(t, wrapped) + + // Verify the wrapped connection still provides the original data. + buf := make([]byte, len(httpData)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, httpData, buf[:n], "wrapped connection should replay original data") +} + +func TestPeekClientHello_TruncatedHeader(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write only 3 bytes then close, fewer than the 5-byte TLS header. + go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01}) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated header") +} + +func TestPeekClientHello_TruncatedPayload(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + // Write a valid TLS header claiming 100 bytes, but only send 10. + go func() { + header := []byte{0x16, 0x03, 0x01, 0x00, 0x64} // 100 bytes claimed + _, _ = clientConn.Write(header) + _, _ = clientConn.Write(make([]byte, 10)) + clientConn.Close() + }() + + _, _, err := PeekClientHello(serverConn) + assert.Error(t, err, "should error on truncated payload") +} + +func TestPeekClientHello_ZeroLengthRecord(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // TLS handshake header with zero-length payload. 
+ go func() { + _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00}) + }() + + sni, wrapped, err := PeekClientHello(serverConn) + require.NoError(t, err) + assert.Empty(t, sni) + assert.NotNil(t, wrapped) +} + +func TestExtractSNI_InvalidPayload(t *testing.T) { + tests := []struct { + name string + payload []byte + }{ + {"nil", nil}, + {"empty", []byte{}}, + {"too short", []byte{0x01, 0x00}}, + {"wrong handshake type", []byte{0x02, 0x00, 0x00, 0x05, 0x03, 0x03, 0x00, 0x00, 0x00}}, + {"truncated client hello", []byte{0x01, 0x00, 0x00, 0x20}}, // claims 32 bytes but has none + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Empty(t, extractSNI(tt.payload)) + }) + } +} + +func TestPeekedConn_CloseWrite(t *testing.T) { + t.Run("delegates to underlying TCPConn", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + accepted := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err == nil { + accepted <- c + } + }() + + client, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer client.Close() + + server := <-accepted + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + + // CloseWrite should succeed on a real TCP connection. + err = wrapped.CloseWrite() + assert.NoError(t, err) + + // The client should see EOF on reads after CloseWrite. + buf := make([]byte, 1) + _, err = client.Read(buf) + assert.Equal(t, io.EOF, err, "client should see EOF after half-close") + }) + + t.Run("no-op on non-halfcloser", func(t *testing.T) { + // net.Pipe does not implement CloseWrite. 
+ _, server := net.Pipe() + defer server.Close() + + wrapped := newPeekedConn(server, []byte("peeked")) + err := wrapped.CloseWrite() + assert.NoError(t, err, "should be no-op on non-halfcloser") + }) +} + +func TestPeekedConn_ReplayAndPassthrough(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + peeked := []byte("peeked-data") + subsequent := []byte("subsequent-data") + + wrapped := newPeekedConn(serverConn, peeked) + + go func() { + _, _ = clientConn.Write(subsequent) + }() + + // Read should return peeked data first. + buf := make([]byte, len(peeked)) + n, err := io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, peeked, buf[:n]) + + // Then subsequent data from the real connection. + buf = make([]byte, len(subsequent)) + n, err = io.ReadFull(wrapped, buf) + require.NoError(t, err) + assert.Equal(t, subsequent, buf[:n]) +} diff --git a/proxy/internal/types/types.go b/proxy/internal/types/types.go index 41acfef40..bf3731803 100644 --- a/proxy/internal/types/types.go +++ b/proxy/internal/types/types.go @@ -1,5 +1,56 @@ // Package types defines common types used across the proxy package. package types +import ( + "context" + "net" + "time" +) + // AccountID represents a unique identifier for a NetBird account. type AccountID string + +// ServiceID represents a unique identifier for a proxy service. +type ServiceID string + +// ServiceMode describes how a reverse proxy service is exposed. +type ServiceMode string + +const ( + ServiceModeHTTP ServiceMode = "http" + ServiceModeTCP ServiceMode = "tcp" + ServiceModeUDP ServiceMode = "udp" + ServiceModeTLS ServiceMode = "tls" +) + +// IsL4 returns true for TCP, UDP, and TLS modes. +func (m ServiceMode) IsL4() bool { + return m == ServiceModeTCP || m == ServiceModeUDP || m == ServiceModeTLS +} + +// RelayDirection indicates the direction of a relayed packet. 
+type RelayDirection string + +const ( + RelayDirectionClientToBackend RelayDirection = "client_to_backend" + RelayDirectionBackendToClient RelayDirection = "backend_to_client" +) + +// DialContextFunc dials a backend through the WireGuard tunnel. +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// dialTimeoutKey is the context key for a per-request dial timeout. +type dialTimeoutKey struct{} + +// WithDialTimeout returns a context carrying a dial timeout that +// DialContext wrappers can use to scope the timeout to just the +// connection establishment phase. +func WithDialTimeout(ctx context.Context, d time.Duration) context.Context { + return context.WithValue(ctx, dialTimeoutKey{}, d) +} + +// DialTimeoutFromContext returns the dial timeout from the context, if set. +func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) { + d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration) + return d, ok && d > 0 +} diff --git a/proxy/internal/types/types_test.go b/proxy/internal/types/types_test.go new file mode 100644 index 000000000..dd9738442 --- /dev/null +++ b/proxy/internal/types/types_test.go @@ -0,0 +1,54 @@ +package types + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestServiceMode_IsL4(t *testing.T) { + tests := []struct { + mode ServiceMode + want bool + }{ + {ServiceModeHTTP, false}, + {ServiceModeTCP, true}, + {ServiceModeUDP, true}, + {ServiceModeTLS, true}, + {ServiceMode("unknown"), false}, + } + + for _, tt := range tests { + t.Run(string(tt.mode), func(t *testing.T) { + assert.Equal(t, tt.want, tt.mode.IsL4()) + }) + } +} + +func TestDialTimeoutContext(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 5*time.Second) + d, ok := DialTimeoutFromContext(ctx) + assert.True(t, ok) + assert.Equal(t, 5*time.Second, d) + }) + + t.Run("missing", func(t *testing.T) { + _, ok := 
DialTimeoutFromContext(context.Background()) + assert.False(t, ok) + }) + + t.Run("zero returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), 0) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "zero duration should return ok=false") + }) + + t.Run("negative returns false", func(t *testing.T) { + ctx := WithDialTimeout(context.Background(), -1*time.Second) + _, ok := DialTimeoutFromContext(ctx) + assert.False(t, ok, "negative duration should return ok=false") + }) +} diff --git a/proxy/internal/udp/relay.go b/proxy/internal/udp/relay.go new file mode 100644 index 000000000..f2f58e858 --- /dev/null +++ b/proxy/internal/udp/relay.go @@ -0,0 +1,496 @@ +package udp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" + + "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/netutil" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +const ( + // DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup. + DefaultSessionTTL = 30 * time.Second + // cleanupInterval is how often the cleaner goroutine runs. + cleanupInterval = time.Minute + // maxPacketSize is the maximum UDP packet size we'll handle. + maxPacketSize = 65535 + // DefaultMaxSessions is the default cap on concurrent UDP sessions per relay. + DefaultMaxSessions = 1024 + // sessionCreateRate limits new session creation per second. + sessionCreateRate = 50 + // sessionCreateBurst is the burst allowance for session creation. + sessionCreateBurst = 100 + // defaultDialTimeout is the fallback dial timeout for backend connections. + defaultDialTimeout = 30 * time.Second +) + +// l4Logger sends layer-4 access log entries to the management server. +type l4Logger interface { + LogL4(entry accesslog.L4Entry) +} + +// SessionObserver receives callbacks for UDP session lifecycle events. 
+// All methods must be safe for concurrent use. +type SessionObserver interface { + UDPSessionStarted(accountID types.AccountID) + UDPSessionEnded(accountID types.AccountID) + UDPSessionDialError(accountID types.AccountID) + UDPSessionRejected(accountID types.AccountID) + UDPPacketRelayed(direction types.RelayDirection, bytes int) +} + +// clientAddr is a typed key for UDP session lookups. +type clientAddr string + +// Relay listens for incoming UDP packets on a dedicated port and +// maintains per-client sessions that relay packets to a backend +// through the WireGuard tunnel. +type Relay struct { + logger *log.Entry + listener net.PacketConn + target string + domain string + accountID types.AccountID + serviceID types.ServiceID + dialFunc types.DialContextFunc + dialTimeout time.Duration + sessionTTL time.Duration + maxSessions int + + mu sync.RWMutex + sessions map[clientAddr]*session + + bufPool sync.Pool + sessLimiter *rate.Limiter + sessWg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + observer SessionObserver + accessLog l4Logger +} + +type session struct { + backend net.Conn + addr net.Addr + createdAt time.Time + // lastSeen stores the last activity timestamp as unix nanoseconds. + lastSeen atomic.Int64 + cancel context.CancelFunc + // bytesIn tracks total bytes received from the client. + bytesIn atomic.Int64 + // bytesOut tracks total bytes sent back to the client. + bytesOut atomic.Int64 +} + +func (s *session) updateLastSeen() { + s.lastSeen.Store(time.Now().UnixNano()) +} + +func (s *session) idleDuration() time.Duration { + return time.Since(time.Unix(0, s.lastSeen.Load())) +} + +// RelayConfig holds the configuration for a UDP relay. 
+type RelayConfig struct { + Logger *log.Entry + Listener net.PacketConn + Target string + Domain string + AccountID types.AccountID + ServiceID types.ServiceID + DialFunc types.DialContextFunc + DialTimeout time.Duration + SessionTTL time.Duration + MaxSessions int + AccessLog l4Logger +} + +// New creates a UDP relay for the given listener and backend target. +// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions. +// DialTimeout controls how long to wait for backend connections; use 0 for default. +// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL. +func New(parentCtx context.Context, cfg RelayConfig) *Relay { + maxSessions := cfg.MaxSessions + dialTimeout := cfg.DialTimeout + sessionTTL := cfg.SessionTTL + if maxSessions <= 0 { + maxSessions = DefaultMaxSessions + } + if dialTimeout <= 0 { + dialTimeout = defaultDialTimeout + } + if sessionTTL <= 0 { + sessionTTL = DefaultSessionTTL + } + ctx, cancel := context.WithCancel(parentCtx) + return &Relay{ + logger: cfg.Logger, + listener: cfg.Listener, + target: cfg.Target, + domain: cfg.Domain, + accountID: cfg.AccountID, + serviceID: cfg.ServiceID, + accessLog: cfg.AccessLog, + dialFunc: cfg.DialFunc, + dialTimeout: dialTimeout, + sessionTTL: sessionTTL, + maxSessions: maxSessions, + sessions: make(map[clientAddr]*session), + bufPool: sync.Pool{ + New: func() any { + buf := make([]byte, maxPacketSize) + return &buf + }, + }, + sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst), + ctx: ctx, + cancel: cancel, + } +} + +// ServiceID returns the service ID associated with this relay. +func (r *Relay) ServiceID() types.ServiceID { + return r.serviceID +} + +// SetObserver sets the session lifecycle observer. Must be called before Serve. +func (r *Relay) SetObserver(obs SessionObserver) { + r.observer = obs +} + +// Serve starts the relay loop. It blocks until the context is canceled +// or the listener is closed. 
+func (r *Relay) Serve() { + go r.cleanupLoop() + + for { + bufp := r.bufPool.Get().(*[]byte) + buf := *bufp + + n, addr, err := r.listener.ReadFrom(buf) + if err != nil { + r.bufPool.Put(bufp) + if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + return + } + r.logger.Debugf("UDP read: %v", err) + continue + } + + data := buf[:n] + sess, err := r.getOrCreateSession(addr) + if err != nil { + r.bufPool.Put(bufp) + r.logger.Debugf("create UDP session for %s: %v", addr, err) + continue + } + + sess.updateLastSeen() + + nw, err := sess.backend.Write(data) + if err != nil { + r.bufPool.Put(bufp) + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to backend for %s: %v", addr, err) + } + r.removeSession(sess) + continue + } + sess.bytesIn.Add(int64(nw)) + + if r.observer != nil { + r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw) + } + r.bufPool.Put(bufp) + } +} + +// getOrCreateSession returns an existing session or creates a new one. +func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { + key := clientAddr(addr.String()) + + r.mu.RLock() + sess, ok := r.sessions[key] + r.mu.RUnlock() + if ok && sess != nil { + return sess, nil + } + + // Check before taking the write lock: if the relay is shutting down, + // don't create new sessions. This prevents orphaned goroutines when + // Serve() processes a packet that was already read before Close(). + if r.ctx.Err() != nil { + return nil, r.ctx.Err() + } + + r.mu.Lock() + + if sess, ok = r.sessions[key]; ok && sess != nil { + r.mu.Unlock() + return sess, nil + } + if ok { + // Another goroutine is dialing for this key, skip. 
+ r.mu.Unlock() + return nil, fmt.Errorf("session dial in progress for %s", key) + } + + if len(r.sessions) >= r.maxSessions { + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions) + } + + if !r.sessLimiter.Allow() { + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionRejected(r.accountID) + } + return nil, fmt.Errorf("session creation rate limited") + } + + // Reserve the slot with a nil session so concurrent callers for the same + // key see it exists and wait. Release the lock before dialing. + r.sessions[key] = nil + r.mu.Unlock() + + dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout) + backend, err := r.dialFunc(dialCtx, "udp", r.target) + dialCancel() + if err != nil { + r.mu.Lock() + delete(r.sessions, key) + r.mu.Unlock() + if r.observer != nil { + r.observer.UDPSessionDialError(r.accountID) + } + return nil, fmt.Errorf("dial backend %s: %w", r.target, err) + } + + sessCtx, sessCancel := context.WithCancel(r.ctx) + sess = &session{ + backend: backend, + addr: addr, + createdAt: time.Now(), + cancel: sessCancel, + } + sess.updateLastSeen() + + r.mu.Lock() + r.sessions[key] = sess + r.mu.Unlock() + + if r.observer != nil { + r.observer.UDPSessionStarted(r.accountID) + } + + r.sessWg.Go(func() { + r.relayBackendToClient(sessCtx, sess) + }) + + r.logger.Debugf("UDP session created for %s", addr) + return sess, nil +} + +// relayBackendToClient reads packets from the backend and writes them +// back to the client through the public-facing listener. 
+func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) { + bufp := r.bufPool.Get().(*[]byte) + defer r.bufPool.Put(bufp) + defer r.removeSession(sess) + + for ctx.Err() == nil { + data, ok := r.readBackendPacket(sess, *bufp) + if !ok { + return + } + if data == nil { + continue + } + + sess.updateLastSeen() + + nw, err := r.listener.WriteTo(data, sess.addr) + if err != nil { + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP write to client %s: %v", sess.addr, err) + } + return + } + sess.bytesOut.Add(int64(nw)) + + if r.observer != nil { + r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw) + } + } +} + +// readBackendPacket reads one packet from the backend with an idle deadline. +// Returns (data, true) on success, (nil, true) on idle timeout that should +// retry, or (nil, false) when the session should be torn down. +func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) { + if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil { + r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err) + return nil, false + } + + n, err := sess.backend.Read(buf) + if err != nil { + if netutil.IsTimeout(err) { + if sess.idleDuration() > r.sessionTTL { + return nil, false + } + return nil, true + } + if !netutil.IsExpectedError(err) { + r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err) + } + return nil, false + } + + return buf[:n], true +} + +// cleanupLoop periodically removes idle sessions. +func (r *Relay) cleanupLoop() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-r.ctx.Done(): + return + case <-ticker.C: + r.cleanupIdleSessions() + } + } +} + +// cleanupIdleSessions closes sessions that have been idle for too long. 
+func (r *Relay) cleanupIdleSessions() { + var expired []*session + + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + continue + } + idle := sess.idleDuration() + if idle > r.sessionTTL { + r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load()) + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close idle session %s backend: %v", sess.addr, err) + } + expired = append(expired, sess) + } + } + r.mu.Unlock() + + for _, sess := range expired { + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// removeSession removes a session from the map if it still matches the +// given pointer. This is safe to call concurrently with cleanupIdleSessions +// because the identity check prevents double-close when both paths race. +func (r *Relay) removeSession(sess *session) { + r.mu.Lock() + key := clientAddr(sess.addr.String()) + removed := r.sessions[key] == sess + if removed { + delete(r.sessions, key) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + } + r.mu.Unlock() + + if removed { + r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } +} + +// logSessionEnd sends an access log entry for a completed UDP session. 
+func (r *Relay) logSessionEnd(sess *session) { + if r.accessLog == nil { + return + } + + var sourceIP netip.Addr + if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil { + sourceIP = ap.Addr().Unmap() + } + + r.accessLog.LogL4(accesslog.L4Entry{ + AccountID: r.accountID, + ServiceID: r.serviceID, + Protocol: accesslog.ProtocolUDP, + Host: r.domain, + SourceIP: sourceIP, + DurationMs: time.Unix(0, sess.lastSeen.Load()).Sub(sess.createdAt).Milliseconds(), + BytesUpload: sess.bytesIn.Load(), + BytesDownload: sess.bytesOut.Load(), + }) +} + +// Close stops the relay, waits for all session goroutines to exit, +// and cleans up remaining sessions. +func (r *Relay) Close() { + r.cancel() + if err := r.listener.Close(); err != nil { + r.logger.Debugf("close UDP listener: %v", err) + } + + var closedSessions []*session + r.mu.Lock() + for key, sess := range r.sessions { + if sess == nil { + delete(r.sessions, key) + continue + } + r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)", + sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) + sess.cancel() + if err := sess.backend.Close(); err != nil { + r.logger.Debugf("close session %s backend: %v", sess.addr, err) + } + delete(r.sessions, key) + closedSessions = append(closedSessions, sess) + } + r.mu.Unlock() + + for _, sess := range closedSessions { + if r.observer != nil { + r.observer.UDPSessionEnded(r.accountID) + } + r.logSessionEnd(sess) + } + + r.sessWg.Wait() +} diff --git a/proxy/internal/udp/relay_test.go b/proxy/internal/udp/relay_test.go new file mode 100644 index 000000000..a1e91b290 --- /dev/null +++ b/proxy/internal/udp/relay_test.go @@ -0,0 +1,493 @@ +package udp + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" +) + +func TestRelay_BasicPacketExchange(t 
*testing.T) {
+	// Set up a UDP backend that echoes packets.
+	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer backend.Close()
+
+	go func() {
+		buf := make([]byte, 65535)
+		for {
+			n, addr, err := backend.ReadFrom(buf)
+			if err != nil {
+				return
+			}
+			_, _ = backend.WriteTo(buf[:n], addr)
+		}
+	}()
+
+	// Set up the relay's public-facing listener.
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	backendAddr := backend.LocalAddr().String()
+
+	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backendAddr, DialFunc: dialFunc})
+	go relay.Serve()
+	defer relay.Close()
+
+	// Create a client and send a packet to the relay.
+	client, err := net.Dial("udp", listener.LocalAddr().String())
+	require.NoError(t, err)
+	defer client.Close()
+
+	testData := []byte("hello UDP relay")
+	_, err = client.Write(testData)
+	require.NoError(t, err)
+
+	// Read the echoed response. 
+	if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
+		t.Fatal(err)
+	}
+	buf := make([]byte, 1024)
+	n, err := client.Read(buf)
+	require.NoError(t, err)
+	assert.Equal(t, testData, buf[:n], "should receive echoed packet")
+}
+
+func TestRelay_MultipleClients(t *testing.T) {
+	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer backend.Close()
+
+	// Backend echoes every datagram back to its sender.
+	go func() {
+		buf := make([]byte, 65535)
+		for {
+			n, addr, err := backend.ReadFrom(buf)
+			if err != nil {
+				return
+			}
+			_, _ = backend.WriteTo(buf[:n], addr)
+		}
+	}()
+
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
+	go relay.Serve()
+	defer relay.Close()
+
+	// Two clients, each should get their own session.
+	for i, msg := range []string{"client-1", "client-2"} {
+		client, err := net.Dial("udp", listener.LocalAddr().String())
+		require.NoError(t, err, "client %d", i)
+		defer client.Close()
+
+		_, err = client.Write([]byte(msg))
+		require.NoError(t, err)
+
+		if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
+			t.Fatal(err)
+		}
+		buf := make([]byte, 1024)
+		n, err := client.Read(buf)
+		require.NoError(t, err, "client %d read", i)
+		assert.Equal(t, msg, string(buf[:n]), "client %d should get own echo", i)
+	}
+
+	// Verify two sessions were created. 
+	relay.mu.RLock()
+	sessionCount := len(relay.sessions)
+	relay.mu.RUnlock()
+	assert.Equal(t, 2, sessionCount, "should have two sessions")
+}
+
+func TestRelay_Close(t *testing.T) {
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	// The target is never dialed here: no client traffic arrives, so no
+	// session is created before Close.
+	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: "127.0.0.1:9999", DialFunc: dialFunc})
+
+	done := make(chan struct{})
+	go func() {
+		relay.Serve()
+		close(done)
+	}()
+
+	relay.Close()
+
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Serve did not return after Close")
+	}
+}
+
+func TestRelay_SessionCleanup(t *testing.T) {
+	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer backend.Close()
+
+	// Backend echoes every datagram back to its sender.
+	go func() {
+		buf := make([]byte, 65535)
+		for {
+			n, addr, err := backend.ReadFrom(buf)
+			if err != nil {
+				return
+			}
+			_, _ = backend.WriteTo(buf[:n], addr)
+		}
+	}()
+
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
+	go relay.Serve()
+	defer relay.Close()
+
+	// Create a session. 
+	client, err := net.Dial("udp", listener.LocalAddr().String())
+	require.NoError(t, err)
+	_, err = client.Write([]byte("hello"))
+	require.NoError(t, err)
+
+	if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
+		t.Fatal(err)
+	}
+	buf := make([]byte, 1024)
+	_, err = client.Read(buf)
+	require.NoError(t, err)
+	client.Close()
+
+	// Verify session exists.
+	relay.mu.RLock()
+	assert.Equal(t, 1, len(relay.sessions))
+	relay.mu.RUnlock()
+
+	// Make session appear idle by setting lastSeen to the past.
+	relay.mu.Lock()
+	for _, sess := range relay.sessions {
+		sess.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano())
+	}
+	relay.mu.Unlock()
+
+	// Trigger cleanup manually.
+	relay.cleanupIdleSessions()
+
+	relay.mu.RLock()
+	assert.Equal(t, 0, len(relay.sessions), "idle sessions should be cleaned up")
+	relay.mu.RUnlock()
+}
+
+// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new
+// one on the same port works cleanly (simulates port mapping modify cycle).
+func TestRelay_CloseAndRecreate(t *testing.T) {
+	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer backend.Close()
+
+	// Backend echoes every datagram back to its sender.
+	go func() {
+		buf := make([]byte, 65535)
+		for {
+			n, addr, err := backend.ReadFrom(buf)
+			if err != nil {
+				return
+			}
+			_, _ = backend.WriteTo(buf[:n], addr)
+		}
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	// First relay. 
+	ln1, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	relay1 := New(ctx, RelayConfig{Logger: logger, Listener: ln1, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
+	go relay1.Serve()
+
+	client1, err := net.Dial("udp", ln1.LocalAddr().String())
+	require.NoError(t, err)
+	_, err = client1.Write([]byte("relay1"))
+	require.NoError(t, err)
+	require.NoError(t, client1.SetReadDeadline(time.Now().Add(2*time.Second)))
+	buf := make([]byte, 1024)
+	n, err := client1.Read(buf)
+	require.NoError(t, err)
+	assert.Equal(t, "relay1", string(buf[:n]))
+	client1.Close()
+
+	// Close first relay.
+	relay1.Close()
+
+	// Second relay on same port.
+	port := ln1.LocalAddr().(*net.UDPAddr).Port
+	ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port))
+	require.NoError(t, err)
+
+	relay2 := New(ctx, RelayConfig{Logger: logger, Listener: ln2, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
+	go relay2.Serve()
+	defer relay2.Close()
+
+	client2, err := net.Dial("udp", ln2.LocalAddr().String())
+	require.NoError(t, err)
+	defer client2.Close()
+	_, err = client2.Write([]byte("relay2"))
+	require.NoError(t, err)
+	require.NoError(t, client2.SetReadDeadline(time.Now().Add(2*time.Second)))
+	n, err = client2.Read(buf)
+	require.NoError(t, err)
+	assert.Equal(t, "relay2", string(buf[:n]), "second relay should work on same port")
+}
+
+func TestRelay_SessionLimit(t *testing.T) {
+	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer backend.Close()
+
+	// Backend echoes every datagram back to its sender.
+	go func() {
+		buf := make([]byte, 65535)
+		for {
+			n, addr, err := backend.ReadFrom(buf)
+			if err != nil {
+				return
+			}
+			_, _ = backend.WriteTo(buf[:n], addr)
+		}
+	}()
+
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	dialFunc := func(ctx 
context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	// Create a relay with a max of 2 sessions.
+	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc, MaxSessions: 2})
+	go relay.Serve()
+	defer relay.Close()
+
+	// Create 2 clients to fill up the session limit.
+	for i := range 2 {
+		client, err := net.Dial("udp", listener.LocalAddr().String())
+		require.NoError(t, err, "client %d", i)
+		defer client.Close()
+
+		_, err = client.Write([]byte("hello"))
+		require.NoError(t, err)
+
+		require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
+		buf := make([]byte, 1024)
+		_, err = client.Read(buf)
+		require.NoError(t, err, "client %d should get response", i)
+	}
+
+	relay.mu.RLock()
+	assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions")
+	relay.mu.RUnlock()
+
+	// Third client should get its packet dropped (session creation fails).
+	client3, err := net.Dial("udp", listener.LocalAddr().String())
+	require.NoError(t, err)
+	defer client3.Close()
+
+	_, err = client3.Write([]byte("should be dropped"))
+	require.NoError(t, err)
+
+	require.NoError(t, client3.SetReadDeadline(time.Now().Add(500*time.Millisecond)))
+	buf := make([]byte, 1024)
+	_, err = client3.Read(buf)
+	assert.Error(t, err, "third client should time out because session was rejected")
+
+	relay.mu.RLock()
+	assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit")
+	relay.mu.RUnlock()
+}
+
+// testObserver records UDP session lifecycle events for test assertions. 
+type testObserver struct {
+	mu sync.Mutex
+	// All counters below are guarded by mu.
+	started  int
+	ended    int
+	rejected int
+	dialErr  int
+	packets  int
+	bytes    int
+}
+
+func (o *testObserver) UDPSessionStarted(types.AccountID)   { o.mu.Lock(); o.started++; o.mu.Unlock() }
+func (o *testObserver) UDPSessionEnded(types.AccountID)     { o.mu.Lock(); o.ended++; o.mu.Unlock() }
+func (o *testObserver) UDPSessionDialError(types.AccountID) { o.mu.Lock(); o.dialErr++; o.mu.Unlock() }
+func (o *testObserver) UDPSessionRejected(types.AccountID)  { o.mu.Lock(); o.rejected++; o.mu.Unlock() }
+func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) {
+	o.mu.Lock()
+	o.packets++
+	o.bytes += b
+	o.mu.Unlock()
+}
+
+func TestRelay_CloseFiresObserverEnded(t *testing.T) {
+	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer backend.Close()
+
+	// Backend echoes every datagram back to its sender.
+	go func() {
+		buf := make([]byte, 65535)
+		for {
+			n, addr, err := backend.ReadFrom(buf)
+			if err != nil {
+				return
+			}
+			_, _ = backend.WriteTo(buf[:n], addr)
+		}
+	}()
+
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	obs := &testObserver{}
+	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc})
+	relay.SetObserver(obs)
+	go relay.Serve()
+
+	// Create two sessions. 
+	for i := range 2 {
+		client, err := net.Dial("udp", listener.LocalAddr().String())
+		require.NoError(t, err, "client %d", i)
+
+		_, err = client.Write([]byte("hello"))
+		require.NoError(t, err)
+
+		require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
+		buf := make([]byte, 1024)
+		_, err = client.Read(buf)
+		require.NoError(t, err)
+		client.Close()
+	}
+
+	obs.mu.Lock()
+	assert.Equal(t, 2, obs.started, "should have 2 started events")
+	obs.mu.Unlock()
+
+	// Close should fire UDPSessionEnded for all remaining sessions.
+	relay.Close()
+
+	obs.mu.Lock()
+	assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session")
+	obs.mu.Unlock()
+}
+
+func TestRelay_SessionRateLimit(t *testing.T) {
+	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer backend.Close()
+
+	// Backend echoes every datagram back to its sender.
+	go func() {
+		buf := make([]byte, 65535)
+		for {
+			n, addr, err := backend.ReadFrom(buf)
+			if err != nil {
+				return
+			}
+			_, _ = backend.WriteTo(buf[:n], addr)
+		}
+	}()
+
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewEntry(log.StandardLogger())
+	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
+		return net.Dial(network, address)
+	}
+
+	obs := &testObserver{}
+	// High max sessions (1000) but the relay uses a rate limiter internally
+	// (default: 50/s burst 100). We exhaust the burst by creating sessions
+	// rapidly, then verify that subsequent creates are rejected.
+	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc, MaxSessions: 1000})
+	relay.SetObserver(obs)
+	go relay.Serve()
+	defer relay.Close()
+
+	// Exhaust the burst by calling getOrCreateSession directly with
+	// synthetic addresses. This is faster than real UDP round-trips. 
+ for i := range sessionCreateBurst + 20 { + addr := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i} + _, _ = relay.getOrCreateSession(addr) + } + + obs.mu.Lock() + rejected := obs.rejected + obs.mu.Unlock() + + assert.Greater(t, rejected, 0, "some sessions should be rate-limited") +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 6a0ecce30..ebecfc6f6 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -243,6 +243,10 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string { return nil } +func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { + return nil +} + // storeBackedServiceManager reads directly from the real store. type storeBackedServiceManager struct { store store.Store @@ -505,15 +509,15 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T nil, "", 0, - mapping.GetAccountId(), - mapping.GetId(), + proxytypes.AccountID(mapping.GetAccountId()), + proxytypes.ServiceID(mapping.GetId()), ) require.NoError(t, err) // Apply to real proxy (idempotent) proxyHandler.AddMapping(proxy.Mapping{ Host: mapping.GetDomain(), - ID: mapping.GetId(), + ID: proxytypes.ServiceID(mapping.GetId()), AccountID: proxytypes.AccountID(mapping.GetAccountId()), }) } diff --git a/proxy/server.go b/proxy/server.go index 62e8368e6..649d49c9a 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -30,6 +30,7 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/exporters/prometheus" "go.opentelemetry.io/otel/sdk/metric" + "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -46,15 +47,26 @@ import ( "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/k8s" proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" + 
"github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/roundtrip" + nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" "github.com/netbirdio/netbird/proxy/internal/types" + udprelay "github.com/netbirdio/netbird/proxy/internal/udp" "github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util/embeddedroots" ) + +// portRouter bundles a per-port Router with its listener and cancel func. +type portRouter struct { + router *nbtcp.Router + listener net.Listener + cancel context.CancelFunc +} + type Server struct { mgmtClient proto.ProxyServiceClient proxy *proxy.ReverseProxy @@ -67,12 +79,27 @@ type Server struct { healthServer *health.Server healthChecker *health.Checker meter *proxymetrics.Metrics + accessLog *accesslog.Logger + mainRouter *nbtcp.Router + mainPort uint16 + udpMu sync.Mutex + udpRelays map[types.ServiceID]*udprelay.Relay + udpRelayWg sync.WaitGroup + portMu sync.RWMutex + portRouters map[uint16]*portRouter + svcPorts map[types.ServiceID][]uint16 + lastMappings map[types.ServiceID]*proto.ProxyMapping + portRouterWg sync.WaitGroup // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) // so they can be closed during graceful shutdown, since http.Server.Shutdown // does not handle them. hijackTracker conntrack.HijackTracker + // routerReady is closed once mainRouter is fully initialized. + // The mapping worker waits on this before processing updates. + routerReady chan struct{} + // Mostly used for debugging on management. startTime time.Time @@ -118,28 +145,36 @@ type Server struct { // When set, forwarding headers from these sources are preserved and // appended to instead of being stripped. TrustedProxies []netip.Prefix - // WireguardPort is the port for the WireGuard interface. 
Use 0 for a - // random OS-assigned port. A fixed port only works with single-account - // deployments; multiple accounts will fail to bind the same port. - WireguardPort int + // WireguardPort is the port for the NetBird tunnel interface. Use 0 + // for a random OS-assigned port. A fixed port only works with + // single-account deployments; multiple accounts will fail to bind + // the same port. + WireguardPort uint16 // ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners. // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. ProxyProtocol bool // PreSharedKey used for tunnel between proxy and peers (set globally not per account) PreSharedKey string + // SupportsCustomPorts indicates whether the proxy can bind arbitrary + // ports for TCP/UDP/TLS services. + SupportsCustomPorts bool + // DefaultDialTimeout is the default timeout for establishing backend + // connections when no per-service timeout is configured. Zero means + // each transport uses its own hardcoded default (typically 30s). + DefaultDialTimeout time.Duration } -// NotifyStatus sends a status update to management about tunnel connectivity -func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error { +// NotifyStatus sends a status update to management about tunnel connectivity. 
+func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED if connected { status = proto.ProxyStatus_PROXY_STATUS_ACTIVE } _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: status, CertificateIssued: false, }) @@ -147,10 +182,10 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID, serviceID, domain } // NotifyCertificateIssued sends a notification to management that a certificate was issued -func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error { +func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error { _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ - ServiceId: serviceID, - AccountId: accountID, + ServiceId: string(serviceID), + AccountId: string(accountID), Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE, CertificateIssued: true, }) @@ -159,6 +194,11 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.initDefaults() + s.routerReady = make(chan struct{}) + s.udpRelays = make(map[types.ServiceID]*udprelay.Relay) + s.portRouters = make(map[uint16]*portRouter) + s.svcPorts = make(map[types.ServiceID][]uint16) + s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping) exporter, err := prometheus.New() if err != nil { @@ -184,7 +224,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { } }() s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) - go s.newManagementMappingWorker(ctx, s.mgmtClient) + runCtx, runCancel := context.WithCancel(ctx) + defer runCancel() + go 
s.newManagementMappingWorker(runCtx, s.mgmtClient) // Initialize the netbird client, this is required to build peer connections // to proxy over. @@ -206,7 +248,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) // Configure Access logs to management server. - accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) + s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) s.healthChecker = health.NewChecker(s.Logger, s.netbird) @@ -220,18 +262,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { handler := http.Handler(s.proxy) handler = s.auth.Protect(handler) handler = web.AssetHandler(handler) - handler = accessLog.Middleware(handler) + handler = s.accessLog.Middleware(handler) handler = s.meter.Middleware(handler) handler = s.hijackTracker.Middleware(handler) - // Start the reverse proxy HTTPS server. - s.https = &http.Server{ - Addr: addr, - Handler: handler, - TLSConfig: tlsConfig, - ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), - } - + // Start a raw TCP listener; the SNI router peeks at ClientHello + // and routes to either the HTTP handler or a TCP relay. lc := net.ListenConfig{} ln, err := lc.Listen(ctx, "tcp", addr) if err != nil { @@ -240,11 +276,34 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { if s.ProxyProtocol { ln = s.wrapProxyProtocol(ln) } + s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid + + // Set up the SNI router for TCP/HTTP multiplexing on the main port. + s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr()) + s.mainRouter.SetObserver(s.meter) + s.mainRouter.SetAccessLogger(s.accessLog) + close(s.routerReady) + + // The HTTP server uses the chanListener fed by the SNI router. 
+ s.https = &http.Server{ + Addr: addr, + Handler: handler, + TLSConfig: tlsConfig, + ReadHeaderTimeout: httpReadHeaderTimeout, + IdleTimeout: httpIdleTimeout, + ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), + } httpsErr := make(chan error, 1) go func() { - s.Logger.Debugf("starting reverse proxy server on %s", addr) - httpsErr <- s.https.ServeTLS(ln, "", "") + s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") + httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "") + }() + + routerErr := make(chan error, 1) + go func() { + s.Logger.Debugf("starting SNI router on %s", addr) + routerErr <- s.mainRouter.Serve(runCtx, ln) }() select { @@ -254,6 +313,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { return fmt.Errorf("https server: %w", err) } return nil + case err := <-routerErr: + s.shutdownServices() + if err != nil { + return fmt.Errorf("SNI router: %w", err) + } + return nil case <-ctx.Done(): s.gracefulShutdown() return nil @@ -381,6 +446,13 @@ const ( // shutdownServiceTimeout is the maximum time to wait for auxiliary // services (health probe, debug endpoint, ACME) to shut down. shutdownServiceTimeout = 5 * time.Second + + // httpReadHeaderTimeout limits how long the server waits to read + // request headers after accepting a connection. Prevents slowloris. + httpReadHeaderTimeout = 10 * time.Second + // httpIdleTimeout limits how long an idle keep-alive connection + // stays open before the server closes it. + httpIdleTimeout = 120 * time.Second ) func (s *Server) dialManagement() (*grpc.ClientConn, error) { @@ -518,6 +590,9 @@ func (s *Server) gracefulShutdown() { s.Logger.Infof("closed %d hijacked connection(s)", n) } + // Drain all router relay connections (main + per-port) in parallel. + s.drainAllRouters(shutdownDrainTimeout) + // Step 5: Stop all remaining background services. 
s.shutdownServices() s.Logger.Info("graceful shutdown complete") @@ -525,6 +600,34 @@ func (s *Server) gracefulShutdown() { // shutdownServices stops all background services concurrently and waits for // them to finish. +// drainAllRouters drains active relay connections on the main router and +// all per-port routers in parallel, up to the given timeout. +func (s *Server) drainAllRouters(timeout time.Duration) { + var wg sync.WaitGroup + + drain := func(name string, router *nbtcp.Router) { + wg.Add(1) + go func() { + defer wg.Done() + if ok := router.Drain(timeout); !ok { + s.Logger.Warnf("timed out draining %s relay connections", name) + } + }() + } + + if s.mainRouter != nil { + drain("main router", s.mainRouter) + } + + s.portMu.RLock() + for port, pr := range s.portRouters { + drain(fmt.Sprintf("port %d", port), pr.router) + } + s.portMu.RUnlock() + + wg.Wait() +} + func (s *Server) shutdownServices() { var wg sync.WaitGroup @@ -562,9 +665,165 @@ func (s *Server) shutdownServices() { }() } + // Close all UDP relays and wait for their goroutines to exit. + s.udpMu.Lock() + for id, relay := range s.udpRelays { + relay.Close() + delete(s.udpRelays, id) + } + s.udpMu.Unlock() + s.udpRelayWg.Wait() + + // Close all per-port routers. + s.portMu.Lock() + for port, pr := range s.portRouters { + pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + } + maps.Clear(s.svcPorts) + maps.Clear(s.lastMappings) + s.portMu.Unlock() + + // Wait for per-port router serve goroutines to exit. + s.portRouterWg.Wait() + wg.Wait() } +// resolveDialFunc returns a DialContextFunc that dials through the +// NetBird tunnel for the given account. 
+func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFunc, error) {
+	client, ok := s.netbird.GetClient(accountID)
+	if !ok {
+		return nil, fmt.Errorf("no client for account %s", accountID)
+	}
+	return client.DialContext, nil
+}
+
+// notifyError reports a resource error back to management so it can be
+// surfaced to the user (e.g. port bind failure, dialer resolution error).
+func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
+	s.sendStatusUpdate(ctx, types.AccountID(mapping.GetAccountId()), types.ServiceID(mapping.GetId()), proto.ProxyStatus_PROXY_STATUS_ERROR, err)
+}
+
+// sendStatusUpdate sends a status update for a service to management.
+func (s *Server) sendStatusUpdate(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, st proto.ProxyStatus, err error) {
+	req := &proto.SendStatusUpdateRequest{
+		ServiceId: string(serviceID),
+		AccountId: string(accountID),
+		Status:    st,
+	}
+	if err != nil {
+		msg := err.Error()
+		req.ErrorMessage = &msg
+	}
+	// Best-effort: failures to deliver the status update are only logged.
+	if _, sendErr := s.mgmtClient.SendStatusUpdate(ctx, req); sendErr != nil {
+		s.Logger.Debugf("failed to send status update for %s: %v", serviceID, sendErr)
+	}
+}
+
+// routerForPort returns the router that handles the given listen port. If port
+// is 0 or matches the main listener port, the main router is returned.
+// Otherwise a new per-port router is created and started.
+func (s *Server) routerForPort(ctx context.Context, port uint16) (*nbtcp.Router, error) {
+	if port == 0 || port == s.mainPort {
+		return s.mainRouter, nil
+	}
+	return s.getOrCreatePortRouter(ctx, port)
+}
+
+// routerForPortExisting returns the router for the given port without creating
+// one. Returns the main router for port 0 / mainPort, or nil if no per-port
+// router exists. 
+func (s *Server) routerForPortExisting(port uint16) *nbtcp.Router {
+	if port == 0 || port == s.mainPort {
+		return s.mainRouter
+	}
+	s.portMu.RLock()
+	pr := s.portRouters[port]
+	s.portMu.RUnlock()
+	if pr != nil {
+		return pr.router
+	}
+	return nil
+}
+
+// getOrCreatePortRouter returns an existing per-port router or creates one
+// with a new TCP listener and starts serving.
+func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp.Router, error) {
+	s.portMu.Lock()
+	defer s.portMu.Unlock()
+
+	if pr, ok := s.portRouters[port]; ok {
+		return pr.router, nil
+	}
+
+	listenAddr := fmt.Sprintf(":%d", port)
+	ln, err := net.Listen("tcp", listenAddr)
+	if err != nil {
+		return nil, fmt.Errorf("listen TCP on %s: %w", listenAddr, err)
+	}
+	// PROXY protocol wrapping mirrors the main listener's configuration.
+	if s.ProxyProtocol {
+		ln = s.wrapProxyProtocol(ln)
+	}
+
+	router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc)
+	router.SetObserver(s.meter)
+	router.SetAccessLogger(s.accessLog)
+	portCtx, cancel := context.WithCancel(ctx)
+
+	s.portRouters[port] = &portRouter{
+		router:   router,
+		listener: ln,
+		cancel:   cancel,
+	}
+
+	// Track the serve goroutine so shutdown can wait for it.
+	s.portRouterWg.Add(1)
+	go func() {
+		defer s.portRouterWg.Done()
+		if err := router.Serve(portCtx, ln); err != nil {
+			s.Logger.Debugf("port %d router stopped: %v", port, err)
+		}
+	}()
+
+	s.Logger.Debugf("started per-port router on %s", listenAddr)
+	return router, nil
+}
+
+// cleanupPortIfEmpty tears down a per-port router if it has no remaining
+// routes or fallback. The main port is never cleaned up. Active relay
+// connections are drained before the listener is closed.
+func (s *Server) cleanupPortIfEmpty(port uint16) {
+	if port == 0 || port == s.mainPort {
+		return
+	}
+
+	s.portMu.Lock()
+	pr, ok := s.portRouters[port]
+	if !ok || !pr.router.IsEmpty() {
+		s.portMu.Unlock()
+		return
+	}
+
+	// Cancel and close the listener while holding the lock so that
+	// getOrCreatePortRouter sees the entry is gone before we drain. 
+ pr.cancel() + if err := pr.listener.Close(); err != nil { + s.Logger.Debugf("close listener on port %d: %v", port, err) + } + delete(s.portRouters, port) + s.portMu.Unlock() + + // Drain active relay connections outside the lock. + if ok := pr.router.Drain(nbtcp.DefaultDrainTimeout); !ok { + s.Logger.Warnf("timed out draining relay connections on port %d", port) + } + s.Logger.Debugf("cleaned up empty per-port router on port %d", port) +} + func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) { bo := &backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, @@ -590,6 +849,9 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr Version: s.Version, StartedAt: timestamppb.New(s.startTime), Address: s.ProxyURL, + Capabilities: &proto.ProxyCapabilities{ + SupportsCustomPorts: &s.SupportsCustomPorts, + }, }) if err != nil { return fmt.Errorf("create mapping stream: %w", err) @@ -626,6 +888,12 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error { + select { + case <-s.routerReady: + case <-ctx.Done(): + return ctx.Err() + } + for { // Check for context completion to gracefully shutdown. select { @@ -662,25 +930,28 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap s.Logger.WithFields(log.Fields{ "type": mapping.GetType(), "domain": mapping.GetDomain(), - "path": mapping.GetPath(), + "mode": mapping.GetMode(), + "port": mapping.GetListenPort(), "id": mapping.GetId(), }).Debug("Processing mapping update") switch mapping.GetType() { case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: if err := s.addMapping(ctx, mapping); err != nil { - // TODO: Retry this? Or maybe notify the management server that this mapping has failed? 
s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), "error": err, }).Error("Error adding new mapping, ignoring this mapping and continuing processing") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED: - if err := s.updateMapping(ctx, mapping); err != nil { + if err := s.modifyMapping(ctx, mapping); err != nil { s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "domain": mapping.GetDomain(), - }).Errorf("failed to update mapping: %v", err) + "error": err, + }).Error("failed to modify mapping") + s.notifyError(ctx, mapping, err) } case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED: s.removeMapping(ctx, mapping) @@ -688,30 +959,89 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap } } +// addMapping registers a service mapping and starts the appropriate relay or routes. func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - serviceID := mapping.GetId() + svcID := types.ServiceID(mapping.GetId()) authToken := mapping.GetAuthToken() - if err := s.netbird.AddPeer(ctx, accountID, d, authToken, serviceID); err != nil { - return fmt.Errorf("create peer for domain %q: %w", d, err) - } - var wildcardHit bool - if s.acme != nil { - wildcardHit = s.acme.AddDomain(d, string(accountID), serviceID) + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.AddPeer(ctx, accountID, svcKey, authToken, svcID); err != nil { + return fmt.Errorf("create peer for service %s: %w", svcID, err) } - // Pass the mapping through to the update function to avoid duplicating the - // setup, currently update is simply a subset of this function, so this - // separation makes sense...to me at least. 
+ if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + if peerErr := s.netbird.RemovePeer(ctx, accountID, svcKey); peerErr != nil { + s.Logger.WithError(peerErr).WithField("service_id", svcID).Warn("failed to remove peer after setup failure") + } + return err + } + s.storeMapping(mapping) + return nil +} + +// modifyMapping updates a service mapping in place without tearing down the +// NetBird peer. It cleans up old routes using the previously stored mapping +// state and re-applies them from the new mapping. +func (s *Server) modifyMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + if old := s.loadMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) + } + if err := s.setupMappingRoutes(ctx, mapping); err != nil { + s.cleanupMappingRoutes(mapping) + return err + } + s.storeMapping(mapping) + return nil +} + +// setupMappingRoutes configures the appropriate routes or relays for the given +// service mapping based on its mode. The NetBird peer must already exist. +func (s *Server) setupMappingRoutes(ctx context.Context, mapping *proto.ProxyMapping) error { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP: + return s.setupTCPMapping(ctx, mapping) + case types.ServiceModeUDP: + return s.setupUDPMapping(ctx, mapping) + case types.ServiceModeTLS: + return s.setupTLSMapping(ctx, mapping) + default: + return s.setupHTTPMapping(ctx, mapping) + } +} + +// setupHTTPMapping configures HTTP reverse proxy, auth, and ACME routes. 
+func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + d := domain.Domain(mapping.GetDomain()) + accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + + if len(mapping.GetPath()) == 0 { + return nil + } + + var wildcardHit bool + if s.acme != nil { + wildcardHit = s.acme.AddDomain(d, accountID, svcID) + } + s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteHTTP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + }) if err := s.updateMapping(ctx, mapping); err != nil { - s.removeMapping(ctx, mapping) return fmt.Errorf("update mapping for domain %q: %w", d, err) } if wildcardHit { - if err := s.NotifyCertificateIssued(ctx, string(accountID), serviceID, string(d)); err != nil { + if err := s.NotifyCertificateIssued(ctx, accountID, svcID, string(d)); err != nil { s.Logger.Warnf("notify certificate ready for domain %q: %v", d, err) } } @@ -719,56 +1049,386 @@ func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) er return nil } +// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port. 
+func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TCP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TCP service %s", svcID) + } + + if s.WireguardPort != 0 && port == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", port) + } + + router, err := s.routerForPort(ctx, port) + if err != nil { + return fmt.Errorf("router for TCP port %d: %w", port, err) + } + + router.SetFallback(nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTCP, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + }) + + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{port} + s.portMu.Unlock() + + s.meter.L4ServiceAdded(types.ServiceModeTCP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupUDPMapping starts a UDP relay on the listen port. 
+func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + port, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("UDP service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for UDP service %s", svcID) + } + + if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil { + return fmt.Errorf("UDP relay for service %s: %w", svcID, err) + } + + s.meter.L4ServiceAdded(types.ServiceModeUDP) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// setupTLSMapping configures a TLS SNI-routed passthrough on the listen port. +func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMapping) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + tlsPort, err := netutil.ValidatePort(mapping.GetListenPort()) + if err != nil { + return fmt.Errorf("TLS service %s: %w", svcID, err) + } + + targetAddr := s.l4TargetAddress(mapping) + if targetAddr == "" { + return fmt.Errorf("empty target address for TLS service %s", svcID) + } + + if s.WireguardPort != 0 && tlsPort == s.WireguardPort { + return fmt.Errorf("port %d conflicts with tunnel port", tlsPort) + } + + router, err := s.routerForPort(ctx, tlsPort) + if err != nil { + return fmt.Errorf("router for TLS port %d: %w", tlsPort, err) + } + + router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTLS, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + }) + + if tlsPort != s.mainPort { + s.portMu.Lock() + s.svcPorts[svcID] = []uint16{tlsPort} + 
s.portMu.Unlock() + } + + s.Logger.WithFields(log.Fields{ + "domain": mapping.GetDomain(), + "target": targetAddr, + "port": tlsPort, + "service": svcID, + }).Info("TLS passthrough mapping added") + + s.meter.L4ServiceAdded(types.ServiceModeTLS) + s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) + return nil +} + +// serviceKeyForMapping returns the appropriate ServiceKey for a mapping. +// TCP/UDP use an ID-based key; HTTP/TLS use a domain-based key. +func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.ServiceKey { + switch types.ServiceMode(mapping.GetMode()) { + case types.ServiceModeTCP, types.ServiceModeUDP: + return roundtrip.L4ServiceKey(types.ServiceID(mapping.GetId())) + default: + return roundtrip.DomainServiceKey(mapping.GetDomain()) + } +} + +// l4TargetAddress extracts and validates the target address from a mapping's +// first path entry. Returns empty string if no paths exist or the address is +// not a valid host:port. +func (s *Server) l4TargetAddress(mapping *proto.ProxyMapping) string { + paths := mapping.GetPath() + if len(paths) == 0 { + return "" + } + target := paths[0].GetTarget() + if _, _, err := net.SplitHostPort(target); err != nil { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "target": target, + }).Warnf("invalid L4 target address: %v", err) + return "" + } + return target +} + +// l4ProxyProtocol returns whether the first target has PROXY protocol enabled. +func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool { + paths := mapping.GetPath() + if len(paths) == 0 { + return false + } + return paths[0].GetOptions().GetProxyProtocol() +} + +// l4DialTimeout returns the dial timeout from the first target's options, +// falling back to the server's DefaultDialTimeout. 
+func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetRequestTimeout(); d != nil { + return d.AsDuration() + } + } + return s.DefaultDialTimeout +} + +// l4SessionIdleTimeout returns the configured session idle timeout from the +// mapping options, or 0 to use the relay's default. +func l4SessionIdleTimeout(mapping *proto.ProxyMapping) time.Duration { + paths := mapping.GetPath() + if len(paths) > 0 { + if d := paths[0].GetOptions().GetSessionIdleTimeout(); d != nil { + return d.AsDuration() + } + } + return 0 +} + +// addUDPRelay starts a UDP relay on the specified listen port. +func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, targetAddress string, listenPort uint16) error { + svcID := types.ServiceID(mapping.GetId()) + accountID := types.AccountID(mapping.GetAccountId()) + + if s.WireguardPort != 0 && listenPort == s.WireguardPort { + return fmt.Errorf("UDP port %d conflicts with tunnel port", listenPort) + } + + // Close existing relay if present (idempotent re-add). 
+ s.removeUDPRelay(svcID) + + listenAddr := fmt.Sprintf(":%d", listenPort) + + listener, err := net.ListenPacket("udp", listenAddr) + if err != nil { + return fmt.Errorf("listen UDP on %s: %w", listenAddr, err) + } + + dialFn, err := s.resolveDialFunc(accountID) + if err != nil { + _ = listener.Close() + return fmt.Errorf("resolve dialer for UDP: %w", err) + } + + entry := s.Logger.WithFields(log.Fields{ + "target": targetAddress, + "listen_port": listenPort, + "service_id": svcID, + }) + + relay := udprelay.New(ctx, udprelay.RelayConfig{ + Logger: entry, + Listener: listener, + Target: targetAddress, + Domain: mapping.GetDomain(), + AccountID: accountID, + ServiceID: svcID, + DialFunc: dialFn, + DialTimeout: s.l4DialTimeout(mapping), + SessionTTL: l4SessionIdleTimeout(mapping), + AccessLog: s.accessLog, + }) + relay.SetObserver(s.meter) + + s.udpMu.Lock() + s.udpRelays[svcID] = relay + s.udpMu.Unlock() + + s.udpRelayWg.Go(relay.Serve) + entry.Info("UDP relay added") + return nil +} + func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) error { // Very simple implementation here, we don't touch the existing peer // connection or any existing TLS configuration, we simply overwrite // the auth and proxy mappings. // Note: this does require the management server to always send a // full mapping rather than deltas during a modification. 
+ accountID := types.AccountID(mapping.GetAccountId()) + svcID := types.ServiceID(mapping.GetId()) + var schemes []auth.Scheme if mapping.GetAuth().GetPassword() { - schemes = append(schemes, auth.NewPassword(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPassword(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetPin() { - schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId())) + schemes = append(schemes, auth.NewPin(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetOidc() { - schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId(), s.ForwardedProto)) + schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto)) } maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, mapping.GetAccountId(), mapping.GetId()); err != nil { + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil { return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) } - s.proxy.AddMapping(s.protoToMapping(mapping)) - s.meter.AddMapping(s.protoToMapping(mapping)) + m := s.protoToMapping(ctx, mapping) + s.proxy.AddMapping(m) + s.meter.AddMapping(m) return nil } +// removeMapping tears down routes/relays and the NetBird peer for a service. +// Uses the stored mapping state when available to ensure all previously +// configured routes are cleaned up. 
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { - d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) - if err := s.netbird.RemovePeer(ctx, accountID, d); err != nil { + svcKey := s.serviceKeyForMapping(mapping) + if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil { s.Logger.WithFields(log.Fields{ "account_id": accountID, - "domain": d, + "service_id": mapping.GetId(), "error": err, - }).Error("Error removing NetBird peer connection for domain, continuing additional domain cleanup but peer connection may still exist") + }).Error("failed to remove NetBird peer, continuing cleanup") } - if s.acme != nil { - s.acme.RemoveDomain(d) + + if old := s.deleteMapping(types.ServiceID(mapping.GetId())); old != nil { + s.cleanupMappingRoutes(old) + if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { + s.meter.L4ServiceRemoved(mode) + } + } else { + s.cleanupMappingRoutes(mapping) } - s.auth.RemoveDomain(mapping.GetDomain()) - s.proxy.RemoveMapping(s.protoToMapping(mapping)) - s.meter.RemoveMapping(s.protoToMapping(mapping)) } -func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { +// cleanupMappingRoutes removes HTTP/TLS/L4 routes and custom port state for a +// service without touching the NetBird peer. This is used for both full +// removal and in-place modification of mappings. +func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { + svcID := types.ServiceID(mapping.GetId()) + host := mapping.GetDomain() + + // HTTP/TLS cleanup (only relevant when a domain is set). + if host != "" { + d := domain.Domain(host) + if s.acme != nil { + s.acme.RemoveDomain(d) + } + s.auth.RemoveDomain(host) + if s.proxy.RemoveMapping(proxy.Mapping{Host: host}) { + s.meter.RemoveMapping(proxy.Mapping{Host: host}) + } + // Close hijacked connections (WebSocket) for this domain. 
+ if n := s.hijackTracker.CloseByHost(host); n > 0 { + s.Logger.Debugf("closed %d hijacked connection(s) for %s", n, host) + } + // Remove SNI route from the main router (covers both HTTP and main-port TLS). + s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID) + } + + // Extract and delete tracked custom-port entries atomically. + s.portMu.Lock() + entries := s.svcPorts[svcID] + delete(s.svcPorts, svcID) + s.portMu.Unlock() + + for _, entry := range entries { + if router := s.routerForPortExisting(entry); router != nil { + if host != "" { + router.RemoveRoute(nbtcp.SNIHost(host), svcID) + } else { + router.RemoveFallback(svcID) + } + } + s.cleanupPortIfEmpty(entry) + } + + // UDP relay cleanup (idempotent). + s.removeUDPRelay(svcID) + +} + +// removeUDPRelay stops and removes a UDP relay by service ID. +func (s *Server) removeUDPRelay(svcID types.ServiceID) { + s.udpMu.Lock() + relay, ok := s.udpRelays[svcID] + if ok { + delete(s.udpRelays, svcID) + } + s.udpMu.Unlock() + + if ok { + relay.Close() + s.Logger.WithField("service_id", svcID).Info("UDP relay removed") + } +} + +func (s *Server) storeMapping(mapping *proto.ProxyMapping) { + s.portMu.Lock() + s.lastMappings[types.ServiceID(mapping.GetId())] = mapping + s.portMu.Unlock() +} + +func (s *Server) loadMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.RLock() + m := s.lastMappings[svcID] + s.portMu.RUnlock() + return m +} + +func (s *Server) deleteMapping(svcID types.ServiceID) *proto.ProxyMapping { + s.portMu.Lock() + m := s.lastMappings[svcID] + delete(s.lastMappings, svcID) + s.portMu.Unlock() + return m +} + +func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping) proxy.Mapping { paths := make(map[string]*proxy.PathTarget) for _, pathMapping := range mapping.GetPath() { targetURL, err := url.Parse(pathMapping.GetTarget()) if err != nil { - // TODO: Should we warn management about this so it can be bubbled up to a user to reconfigure? 
s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "account_id": mapping.GetAccountId(), @@ -776,6 +1436,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { "path": pathMapping.GetPath(), "target": pathMapping.GetTarget(), }).WithError(err).Error("failed to parse target URL for path, skipping") + s.notifyError(ctx, mapping, fmt.Errorf("invalid target URL %q for path %q: %w", pathMapping.GetTarget(), pathMapping.GetPath(), err)) continue } @@ -788,10 +1449,13 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { pt.RequestTimeout = d.AsDuration() } } + if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 { + pt.RequestTimeout = s.DefaultDialTimeout + } paths[pathMapping.GetPath()] = pt } return proxy.Mapping{ - ID: mapping.GetId(), + ID: types.ServiceID(mapping.GetId()), AccountID: types.AccountID(mapping.GetAccountId()), Host: mapping.GetDomain(), Paths: paths, diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 9505b3fdf..333f0bf00 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -56,12 +56,14 @@ type ExposeRequest struct { Pin string Password string UserGroups []string + ListenPort uint16 } type ExposeResponse struct { - ServiceName string - Domain string - ServiceURL string + ServiceName string + Domain string + ServiceURL string + PortAutoAssigned bool } // NewClient creates a new client to Management service @@ -790,9 +792,10 @@ func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error { func fromProtoExposeResponse(resp *proto.ExposeServiceResponse) *ExposeResponse { return &ExposeResponse{ - ServiceName: resp.ServiceName, - Domain: resp.Domain, - ServiceURL: resp.ServiceUrl, + ServiceName: resp.ServiceName, + Domain: resp.Domain, + ServiceURL: resp.ServiceUrl, + PortAutoAssigned: resp.PortAutoAssigned, } } @@ -808,6 +811,8 @@ func toProtoExposeServiceRequest(req ExposeRequest) 
(*proto.ExposeServiceRequest protocol = proto.ExposeProtocol_EXPOSE_TCP case int(proto.ExposeProtocol_EXPOSE_UDP): protocol = proto.ExposeProtocol_EXPOSE_UDP + case int(proto.ExposeProtocol_EXPOSE_TLS): + protocol = proto.ExposeProtocol_EXPOSE_TLS default: return nil, fmt.Errorf("invalid expose protocol: %d", req.Protocol) } @@ -820,6 +825,7 @@ func toProtoExposeServiceRequest(req ExposeRequest) (*proto.ExposeServiceRequest Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, + ListenPort: uint32(req.ListenPort), }, nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 6d2967aa9..4b851bf19 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -2836,6 +2836,10 @@ components: format: int64 description: "Bytes downloaded (response body size)" example: 8192 + protocol: + type: string + description: "Protocol type: http, tcp, or udp" + example: "http" required: - id - service_id @@ -2954,6 +2958,20 @@ components: domain: type: string description: Domain for the service + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only) + port_auto_assigned: + type: boolean + description: Whether the listen port was auto-assigned + readOnly: true proxy_cluster: type: string description: The proxy cluster handling this service (derived from domain) @@ -3020,6 +3038,16 @@ components: domain: type: string description: Domain for the service + mode: + type: string + description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + enum: [http, tcp, udp, tls] + default: http + listen_port: + type: integer + minimum: 0 + maximum: 65535 + description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. 
targets: type: array items: @@ -3040,8 +3068,6 @@ components: required: - name - domain - - targets - - auth - enabled ServiceTargetOptions: type: object @@ -3065,6 +3091,12 @@ components: additionalProperties: type: string pattern: '^[^\r\n]*$' + proxy_protocol: + type: boolean + description: Send PROXY Protocol v2 header to this backend (TCP/TLS only) + session_idle_timeout: + type: string + description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. ServiceTarget: type: object properties: @@ -3073,21 +3105,23 @@ components: description: Target ID target_type: type: string - description: Target type (e.g., "peer", "resource") - enum: [peer, resource] + description: Target type + enum: [peer, host, domain, subnet] path: type: string - description: URL path prefix for this target + description: URL path prefix for this target (HTTP only) protocol: type: string description: Protocol to use when connecting to the backend - enum: [http, https] + enum: [http, https, tcp, udp] host: type: string description: Backend ip or domain for this target port: type: integer - description: Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). 
+ minimum: 1 + maximum: 65535 + description: Backend port for this target enabled: type: boolean description: Whether this target is enabled @@ -3194,6 +3228,9 @@ components: target_cluster: type: string description: The proxy cluster this domain is validated against (only for custom domains) + supports_custom_ports: + type: boolean + description: Whether the cluster supports binding arbitrary TCP/UDP ports required: - id - domain @@ -4277,6 +4314,12 @@ components: requires_authentication: description: Requires authentication content: { } + conflict: + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' securitySchemes: BearerAuth: type: http @@ -9621,6 +9664,29 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' + /api/reverse-proxies/clusters: + get: + summary: List available proxy clusters + description: Returns a list of available proxy clusters with their connection status + tags: [ Services ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy clusters + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyCluster' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services @@ -9670,29 +9736,8 @@ paths: "$ref": "#/components/responses/requires_authentication" '403': "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - /api/reverse-proxies/clusters: - get: - summary: List available proxy clusters - description: Returns a list of available proxy clusters with their connection status - tags: [ Services ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - responses: - '200': - description: A JSON Array of proxy clusters - content: - 
application/json: - schema: - type: array - items: - $ref: '#/components/schemas/ProxyCluster' - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services/{serviceId}: @@ -9762,6 +9807,8 @@ paths: "$ref": "#/components/responses/forbidden" '404': "$ref": "#/components/responses/not_found" + '409': + "$ref": "#/components/responses/conflict" '500': "$ref": "#/components/responses/internal_error" delete: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index f5a2b7ced..4ec3b871a 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -880,6 +880,30 @@ func (e SentinelOneMatchAttributesNetworkStatus) Valid() bool { } } +// Defines values for ServiceMode. +const ( + ServiceModeHttp ServiceMode = "http" + ServiceModeTcp ServiceMode = "tcp" + ServiceModeTls ServiceMode = "tls" + ServiceModeUdp ServiceMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceMode enum. +func (e ServiceMode) Valid() bool { + switch e { + case ServiceModeHttp: + return true + case ServiceModeTcp: + return true + case ServiceModeTls: + return true + case ServiceModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceMetaStatus. const ( ServiceMetaStatusActive ServiceMetaStatus = "active" @@ -910,10 +934,36 @@ func (e ServiceMetaStatus) Valid() bool { } } +// Defines values for ServiceRequestMode. +const ( + ServiceRequestModeHttp ServiceRequestMode = "http" + ServiceRequestModeTcp ServiceRequestMode = "tcp" + ServiceRequestModeTls ServiceRequestMode = "tls" + ServiceRequestModeUdp ServiceRequestMode = "udp" +) + +// Valid indicates whether the value is a known member of the ServiceRequestMode enum. 
+func (e ServiceRequestMode) Valid() bool { + switch e { + case ServiceRequestModeHttp: + return true + case ServiceRequestModeTcp: + return true + case ServiceRequestModeTls: + return true + case ServiceRequestModeUdp: + return true + default: + return false + } +} + // Defines values for ServiceTargetProtocol. const ( ServiceTargetProtocolHttp ServiceTargetProtocol = "http" ServiceTargetProtocolHttps ServiceTargetProtocol = "https" + ServiceTargetProtocolTcp ServiceTargetProtocol = "tcp" + ServiceTargetProtocolUdp ServiceTargetProtocol = "udp" ) // Valid indicates whether the value is a known member of the ServiceTargetProtocol enum. @@ -923,6 +973,10 @@ func (e ServiceTargetProtocol) Valid() bool { return true case ServiceTargetProtocolHttps: return true + case ServiceTargetProtocolTcp: + return true + case ServiceTargetProtocolUdp: + return true default: return false } @@ -930,16 +984,22 @@ func (e ServiceTargetProtocol) Valid() bool { // Defines values for ServiceTargetTargetType. const ( - ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" - ServiceTargetTargetTypeResource ServiceTargetTargetType = "resource" + ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain" + ServiceTargetTargetTypeHost ServiceTargetTargetType = "host" + ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" + ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet" ) // Valid indicates whether the value is a known member of the ServiceTargetTargetType enum. 
func (e ServiceTargetTargetType) Valid() bool { switch e { + case ServiceTargetTargetTypeDomain: + return true + case ServiceTargetTargetTypeHost: + return true case ServiceTargetTargetTypePeer: return true - case ServiceTargetTargetTypeResource: + case ServiceTargetTargetTypeSubnet: return true default: return false @@ -3249,6 +3309,9 @@ type ProxyAccessLog struct { // Path Path of the request Path string `json:"path"` + // Protocol Protocol type: http, tcp, or udp + Protocol *string `json:"protocol,omitempty"` + // Reason Reason for the request result (e.g., authentication failure) Reason *string `json:"reason,omitempty"` @@ -3313,6 +3376,9 @@ type ReverseProxyDomain struct { // Id Domain ID Id string `json:"id"` + // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports + SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` + // TargetCluster The proxy cluster this domain is validated against (only for custom domains) TargetCluster *string `json:"target_cluster,omitempty"` @@ -3505,8 +3571,14 @@ type Service struct { Enabled bool `json:"enabled"` // Id Service ID - Id string `json:"id"` - Meta ServiceMeta `json:"meta"` + Id string `json:"id"` + + // ListenPort Port the proxy listens on (L4/TLS only) + ListenPort *int `json:"listen_port,omitempty"` + Meta ServiceMeta `json:"meta"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. 
+ Mode *ServiceMode `json:"mode,omitempty"` // Name Service name Name string `json:"name"` @@ -3514,6 +3586,9 @@ type Service struct { // PassHostHeader When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address PassHostHeader *bool `json:"pass_host_header,omitempty"` + // PortAutoAssigned Whether the listen port was auto-assigned + PortAutoAssigned *bool `json:"port_auto_assigned,omitempty"` + // ProxyCluster The proxy cluster handling this service (derived from domain) ProxyCluster *string `json:"proxy_cluster,omitempty"` @@ -3524,6 +3599,9 @@ type Service struct { Targets []ServiceTarget `json:"targets"` } +// ServiceMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceMode string + // ServiceAuthConfig defines model for ServiceAuthConfig. type ServiceAuthConfig struct { BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"` @@ -3549,7 +3627,7 @@ type ServiceMetaStatus string // ServiceRequest defines model for ServiceRequest. type ServiceRequest struct { - Auth ServiceAuthConfig `json:"auth"` + Auth *ServiceAuthConfig `json:"auth,omitempty"` // Domain Domain for the service Domain string `json:"domain"` @@ -3557,6 +3635,12 @@ type ServiceRequest struct { // Enabled Whether the service is enabled Enabled bool `json:"enabled"` + // ListenPort Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. + ListenPort *int `json:"listen_port,omitempty"` + + // Mode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. + Mode *ServiceRequestMode `json:"mode,omitempty"` + // Name Service name Name string `json:"name"` @@ -3567,9 +3651,12 @@ type ServiceRequest struct { RewriteRedirects *bool `json:"rewrite_redirects,omitempty"` // Targets List of target backends for this service - Targets []ServiceTarget `json:"targets"` + Targets *[]ServiceTarget `json:"targets,omitempty"` } +// ServiceRequestMode Service mode. 
"http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. +type ServiceRequestMode string + // ServiceTarget defines model for ServiceTarget. type ServiceTarget struct { // Enabled Whether this target is enabled @@ -3579,10 +3666,10 @@ type ServiceTarget struct { Host *string `json:"host,omitempty"` Options *ServiceTargetOptions `json:"options,omitempty"` - // Path URL path prefix for this target + // Path URL path prefix for this target (HTTP only) Path *string `json:"path,omitempty"` - // Port Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https). + // Port Backend port for this target Port int `json:"port"` // Protocol Protocol to use when connecting to the backend @@ -3591,14 +3678,14 @@ type ServiceTarget struct { // TargetId Target ID TargetId string `json:"target_id"` - // TargetType Target type (e.g., "peer", "resource") + // TargetType Target type TargetType ServiceTargetTargetType `json:"target_type"` } // ServiceTargetProtocol Protocol to use when connecting to the backend type ServiceTargetProtocol string -// ServiceTargetTargetType Target type (e.g., "peer", "resource") +// ServiceTargetTargetType Target type type ServiceTargetTargetType string // ServiceTargetOptions defines model for ServiceTargetOptions. @@ -3609,9 +3696,15 @@ type ServiceTargetOptions struct { // PathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. PathRewrite *ServiceTargetOptionsPathRewrite `json:"path_rewrite,omitempty"` + // ProxyProtocol Send PROXY Protocol v2 header to this backend (TCP/TLS only) + ProxyProtocol *bool `json:"proxy_protocol,omitempty"` + // RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m") RequestTimeout *string `json:"request_timeout,omitempty"` + // SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. 
"30s", "2m"). Maximum 10m. + SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"` + // SkipTlsVerify Skip TLS certificate verification for this backend SkipTlsVerify *bool `json:"skip_tls_verify,omitempty"` } @@ -4136,6 +4229,9 @@ type ZoneRequest struct { Name string `json:"name"` } +// Conflict Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. +type Conflict = ErrorResponse + // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParams struct { // Page Page number diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 2c66bb946..c5581296c 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -228,6 +228,7 @@ const ( ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1 ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2 ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3 + ExposeProtocol_EXPOSE_TLS ExposeProtocol = 4 ) // Enum value maps for ExposeProtocol. 
@@ -237,12 +238,14 @@ var ( 1: "EXPOSE_HTTPS", 2: "EXPOSE_TCP", 3: "EXPOSE_UDP", + 4: "EXPOSE_TLS", } ExposeProtocol_value = map[string]int32{ "EXPOSE_HTTP": 0, "EXPOSE_HTTPS": 1, "EXPOSE_TCP": 2, "EXPOSE_UDP": 3, + "EXPOSE_TLS": 4, } ) @@ -4047,6 +4050,7 @@ type ExposeServiceRequest struct { UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"` Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"` NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"` + ListenPort uint32 `protobuf:"varint,8,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` } func (x *ExposeServiceRequest) Reset() { @@ -4130,14 +4134,22 @@ func (x *ExposeServiceRequest) GetNamePrefix() string { return "" } +func (x *ExposeServiceRequest) GetListenPort() uint32 { + if x != nil { + return x.ListenPort + } + return 0 +} + type ExposeServiceResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` - ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` - Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"` + ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"` + Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"` + PortAutoAssigned bool `protobuf:"varint,4,opt,name=port_auto_assigned,json=portAutoAssigned,proto3" json:"port_auto_assigned,omitempty"` } func (x *ExposeServiceResponse) Reset() { @@ -4193,6 +4205,13 @@ func (x *ExposeServiceResponse) GetDomain() string { return 
"" } +func (x *ExposeServiceResponse) GetPortAutoAssigned() bool { + if x != nil { + return x.PortAutoAssigned + } + return false +} + type RenewExposeRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -4996,7 +5015,7 @@ var file_management_proto_rawDesc = []byte{ 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0xea, 0x01, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, @@ -5010,15 +5029,20 @@ var file_management_proto_rawDesc = []byte{ 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x22, 0x73, - 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 
0x09, 0x52, - 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, + 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x1f, + 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x22, + 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, + 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x61, 0x75, + 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x73, 0x73, 0x69, 0x67, + 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, @@ -5039,12 +5063,13 @@ var 
file_management_proto_rawDesc = []byte{ 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, - 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, 0x53, 0x0a, 0x0e, 0x45, + 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index fdbe3a365..9acf7e2b3 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -652,6 +652,7 @@ enum ExposeProtocol { EXPOSE_HTTPS = 1; EXPOSE_TCP = 2; EXPOSE_UDP = 3; + EXPOSE_TLS = 4; } message ExposeServiceRequest { @@ -662,12 +663,14 @@ message ExposeServiceRequest { repeated string user_groups = 5; string domain = 6; string name_prefix = 7; + uint32 listen_port = 8; } message ExposeServiceResponse { string service_name = 1; string service_url = 2; 
string domain = 3; + bool port_auto_assigned = 4; } message RenewExposeRequest { diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 275e8be37..115ac5101 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v6.33.3 // source: proxy_service.proto package proto @@ -175,22 +175,72 @@ func (ProxyStatus) EnumDescriptor() ([]byte, []int) { return file_proxy_service_proto_rawDescGZIP(), []int{2} } +// ProxyCapabilities describes what a proxy can handle. +type ProxyCapabilities struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + SupportsCustomPorts *bool `protobuf:"varint,1,opt,name=supports_custom_ports,json=supportsCustomPorts,proto3,oneof" json:"supports_custom_ports,omitempty"` +} + +func (x *ProxyCapabilities) Reset() { + *x = ProxyCapabilities{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProxyCapabilities) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProxyCapabilities) ProtoMessage() {} + +func (x *ProxyCapabilities) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProxyCapabilities.ProtoReflect.Descriptor instead. 
+func (*ProxyCapabilities) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{0} +} + +func (x *ProxyCapabilities) GetSupportsCustomPorts() bool { + if x != nil && x.SupportsCustomPorts != nil { + return *x.SupportsCustomPorts + } + return false +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. type GetMappingUpdateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` - Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` - StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` - Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` } func (x *GetMappingUpdateRequest) Reset() { *x = GetMappingUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -203,7 +253,7 @@ func (x *GetMappingUpdateRequest) String() string { func (*GetMappingUpdateRequest) ProtoMessage() {} func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[0] + mi := &file_proxy_service_proto_msgTypes[1] if protoimpl.UnsafeEnabled 
&& x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -216,7 +266,7 @@ func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateRequest.ProtoReflect.Descriptor instead. func (*GetMappingUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{0} + return file_proxy_service_proto_rawDescGZIP(), []int{1} } func (x *GetMappingUpdateRequest) GetProxyId() string { @@ -247,6 +297,13 @@ func (x *GetMappingUpdateRequest) GetAddress() string { return "" } +func (x *GetMappingUpdateRequest) GetCapabilities() *ProxyCapabilities { + if x != nil { + return x.Capabilities + } + return nil +} + // GetMappingUpdateResponse contains zero or more ProxyMappings. // No mappings may be sent to test the liveness of the Proxy. // Mappings that are sent should be interpreted by the Proxy appropriately. @@ -264,7 +321,7 @@ type GetMappingUpdateResponse struct { func (x *GetMappingUpdateResponse) Reset() { *x = GetMappingUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -277,7 +334,7 @@ func (x *GetMappingUpdateResponse) String() string { func (*GetMappingUpdateResponse) ProtoMessage() {} func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[1] + mi := &file_proxy_service_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -290,7 +347,7 @@ func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMappingUpdateResponse.ProtoReflect.Descriptor instead. 
func (*GetMappingUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{1} + return file_proxy_service_proto_rawDescGZIP(), []int{2} } func (x *GetMappingUpdateResponse) GetMapping() []*ProxyMapping { @@ -316,12 +373,16 @@ type PathTargetOptions struct { RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // Send PROXY protocol v2 header to this backend. + ProxyProtocol bool `protobuf:"varint,5,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` + // Idle timeout before a UDP session is reaped. 
+ SessionIdleTimeout *durationpb.Duration `protobuf:"bytes,6,opt,name=session_idle_timeout,json=sessionIdleTimeout,proto3" json:"session_idle_timeout,omitempty"` } func (x *PathTargetOptions) Reset() { *x = PathTargetOptions{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -334,7 +395,7 @@ func (x *PathTargetOptions) String() string { func (*PathTargetOptions) ProtoMessage() {} func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[2] + mi := &file_proxy_service_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -347,7 +408,7 @@ func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { // Deprecated: Use PathTargetOptions.ProtoReflect.Descriptor instead. func (*PathTargetOptions) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{2} + return file_proxy_service_proto_rawDescGZIP(), []int{3} } func (x *PathTargetOptions) GetSkipTlsVerify() bool { @@ -378,6 +439,20 @@ func (x *PathTargetOptions) GetCustomHeaders() map[string]string { return nil } +func (x *PathTargetOptions) GetProxyProtocol() bool { + if x != nil { + return x.ProxyProtocol + } + return false +} + +func (x *PathTargetOptions) GetSessionIdleTimeout() *durationpb.Duration { + if x != nil { + return x.SessionIdleTimeout + } + return nil +} + type PathMapping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -391,7 +466,7 @@ type PathMapping struct { func (x *PathMapping) Reset() { *x = PathMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -404,7 +479,7 @@ func (x 
*PathMapping) String() string { func (*PathMapping) ProtoMessage() {} func (x *PathMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[3] + mi := &file_proxy_service_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -417,7 +492,7 @@ func (x *PathMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use PathMapping.ProtoReflect.Descriptor instead. func (*PathMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{3} + return file_proxy_service_proto_rawDescGZIP(), []int{4} } func (x *PathMapping) GetPath() string { @@ -456,7 +531,7 @@ type Authentication struct { func (x *Authentication) Reset() { *x = Authentication{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -469,7 +544,7 @@ func (x *Authentication) String() string { func (*Authentication) ProtoMessage() {} func (x *Authentication) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[4] + mi := &file_proxy_service_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -482,7 +557,7 @@ func (x *Authentication) ProtoReflect() protoreflect.Message { // Deprecated: Use Authentication.ProtoReflect.Descriptor instead. func (*Authentication) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{4} + return file_proxy_service_proto_rawDescGZIP(), []int{5} } func (x *Authentication) GetSessionKey() string { @@ -538,12 +613,16 @@ type ProxyMapping struct { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. 
RewriteRedirects bool `protobuf:"varint,9,opt,name=rewrite_redirects,json=rewriteRedirects,proto3" json:"rewrite_redirects,omitempty"` + // Service mode: "http", "tcp", "udp", or "tls". + Mode string `protobuf:"bytes,10,opt,name=mode,proto3" json:"mode,omitempty"` + // For L4/TLS: the port the proxy listens on. + ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` } func (x *ProxyMapping) Reset() { *x = ProxyMapping{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -556,7 +635,7 @@ func (x *ProxyMapping) String() string { func (*ProxyMapping) ProtoMessage() {} func (x *ProxyMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[5] + mi := &file_proxy_service_proto_msgTypes[6] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -569,7 +648,7 @@ func (x *ProxyMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use ProxyMapping.ProtoReflect.Descriptor instead. func (*ProxyMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{5} + return file_proxy_service_proto_rawDescGZIP(), []int{6} } func (x *ProxyMapping) GetType() ProxyMappingUpdateType { @@ -635,6 +714,20 @@ func (x *ProxyMapping) GetRewriteRedirects() bool { return false } +func (x *ProxyMapping) GetMode() string { + if x != nil { + return x.Mode + } + return "" +} + +func (x *ProxyMapping) GetListenPort() int32 { + if x != nil { + return x.ListenPort + } + return 0 +} + // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. 
type SendAccessLogRequest struct { state protoimpl.MessageState @@ -647,7 +740,7 @@ type SendAccessLogRequest struct { func (x *SendAccessLogRequest) Reset() { *x = SendAccessLogRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -660,7 +753,7 @@ func (x *SendAccessLogRequest) String() string { func (*SendAccessLogRequest) ProtoMessage() {} func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[6] + mi := &file_proxy_service_proto_msgTypes[7] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -673,7 +766,7 @@ func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogRequest.ProtoReflect.Descriptor instead. func (*SendAccessLogRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{6} + return file_proxy_service_proto_rawDescGZIP(), []int{7} } func (x *SendAccessLogRequest) GetLog() *AccessLog { @@ -693,7 +786,7 @@ type SendAccessLogResponse struct { func (x *SendAccessLogResponse) Reset() { *x = SendAccessLogResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -706,7 +799,7 @@ func (x *SendAccessLogResponse) String() string { func (*SendAccessLogResponse) ProtoMessage() {} func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[7] + mi := &file_proxy_service_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -719,7 +812,7 @@ func (x *SendAccessLogResponse) ProtoReflect() 
protoreflect.Message { // Deprecated: Use SendAccessLogResponse.ProtoReflect.Descriptor instead. func (*SendAccessLogResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{7} + return file_proxy_service_proto_rawDescGZIP(), []int{8} } type AccessLog struct { @@ -742,12 +835,13 @@ type AccessLog struct { AuthSuccess bool `protobuf:"varint,13,opt,name=auth_success,json=authSuccess,proto3" json:"auth_success,omitempty"` BytesUpload int64 `protobuf:"varint,14,opt,name=bytes_upload,json=bytesUpload,proto3" json:"bytes_upload,omitempty"` BytesDownload int64 `protobuf:"varint,15,opt,name=bytes_download,json=bytesDownload,proto3" json:"bytes_download,omitempty"` + Protocol string `protobuf:"bytes,16,opt,name=protocol,proto3" json:"protocol,omitempty"` } func (x *AccessLog) Reset() { *x = AccessLog{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -760,7 +854,7 @@ func (x *AccessLog) String() string { func (*AccessLog) ProtoMessage() {} func (x *AccessLog) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[8] + mi := &file_proxy_service_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -773,7 +867,7 @@ func (x *AccessLog) ProtoReflect() protoreflect.Message { // Deprecated: Use AccessLog.ProtoReflect.Descriptor instead. 
func (*AccessLog) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{8} + return file_proxy_service_proto_rawDescGZIP(), []int{9} } func (x *AccessLog) GetTimestamp() *timestamppb.Timestamp { @@ -881,6 +975,13 @@ func (x *AccessLog) GetBytesDownload() int64 { return 0 } +func (x *AccessLog) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + type AuthenticateRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -898,7 +999,7 @@ type AuthenticateRequest struct { func (x *AuthenticateRequest) Reset() { *x = AuthenticateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -911,7 +1012,7 @@ func (x *AuthenticateRequest) String() string { func (*AuthenticateRequest) ProtoMessage() {} func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[9] + mi := &file_proxy_service_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -924,7 +1025,7 @@ func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateRequest.ProtoReflect.Descriptor instead. 
func (*AuthenticateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{9} + return file_proxy_service_proto_rawDescGZIP(), []int{10} } func (x *AuthenticateRequest) GetId() string { @@ -989,7 +1090,7 @@ type PasswordRequest struct { func (x *PasswordRequest) Reset() { *x = PasswordRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1002,7 +1103,7 @@ func (x *PasswordRequest) String() string { func (*PasswordRequest) ProtoMessage() {} func (x *PasswordRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[10] + mi := &file_proxy_service_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1015,7 +1116,7 @@ func (x *PasswordRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PasswordRequest.ProtoReflect.Descriptor instead. 
func (*PasswordRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{10} + return file_proxy_service_proto_rawDescGZIP(), []int{11} } func (x *PasswordRequest) GetPassword() string { @@ -1036,7 +1137,7 @@ type PinRequest struct { func (x *PinRequest) Reset() { *x = PinRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1049,7 +1150,7 @@ func (x *PinRequest) String() string { func (*PinRequest) ProtoMessage() {} func (x *PinRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[11] + mi := &file_proxy_service_proto_msgTypes[12] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1062,7 +1163,7 @@ func (x *PinRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PinRequest.ProtoReflect.Descriptor instead. 
func (*PinRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{11} + return file_proxy_service_proto_rawDescGZIP(), []int{12} } func (x *PinRequest) GetPin() string { @@ -1084,7 +1185,7 @@ type AuthenticateResponse struct { func (x *AuthenticateResponse) Reset() { *x = AuthenticateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1097,7 +1198,7 @@ func (x *AuthenticateResponse) String() string { func (*AuthenticateResponse) ProtoMessage() {} func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[12] + mi := &file_proxy_service_proto_msgTypes[13] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1110,7 +1211,7 @@ func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateResponse.ProtoReflect.Descriptor instead. 
func (*AuthenticateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{12} + return file_proxy_service_proto_rawDescGZIP(), []int{13} } func (x *AuthenticateResponse) GetSuccess() bool { @@ -1143,7 +1244,7 @@ type SendStatusUpdateRequest struct { func (x *SendStatusUpdateRequest) Reset() { *x = SendStatusUpdateRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1156,7 +1257,7 @@ func (x *SendStatusUpdateRequest) String() string { func (*SendStatusUpdateRequest) ProtoMessage() {} func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[13] + mi := &file_proxy_service_proto_msgTypes[14] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1169,7 +1270,7 @@ func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateRequest.ProtoReflect.Descriptor instead. 
func (*SendStatusUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{13} + return file_proxy_service_proto_rawDescGZIP(), []int{14} } func (x *SendStatusUpdateRequest) GetServiceId() string { @@ -1217,7 +1318,7 @@ type SendStatusUpdateResponse struct { func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1230,7 +1331,7 @@ func (x *SendStatusUpdateResponse) String() string { func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[14] + mi := &file_proxy_service_proto_msgTypes[15] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1243,7 +1344,7 @@ func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateResponse.ProtoReflect.Descriptor instead. 
func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{14} + return file_proxy_service_proto_rawDescGZIP(), []int{15} } // CreateProxyPeerRequest is sent by the proxy to create a peer connection @@ -1263,7 +1364,7 @@ type CreateProxyPeerRequest struct { func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1276,7 +1377,7 @@ func (x *CreateProxyPeerRequest) String() string { func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[15] + mi := &file_proxy_service_proto_msgTypes[16] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1289,7 +1390,7 @@ func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerRequest.ProtoReflect.Descriptor instead. 
func (*CreateProxyPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{15} + return file_proxy_service_proto_rawDescGZIP(), []int{16} } func (x *CreateProxyPeerRequest) GetServiceId() string { @@ -1340,7 +1441,7 @@ type CreateProxyPeerResponse struct { func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1353,7 +1454,7 @@ func (x *CreateProxyPeerResponse) String() string { func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[16] + mi := &file_proxy_service_proto_msgTypes[17] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1366,7 +1467,7 @@ func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerResponse.ProtoReflect.Descriptor instead. 
func (*CreateProxyPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{16} + return file_proxy_service_proto_rawDescGZIP(), []int{17} } func (x *CreateProxyPeerResponse) GetSuccess() bool { @@ -1396,7 +1497,7 @@ type GetOIDCURLRequest struct { func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1409,7 +1510,7 @@ func (x *GetOIDCURLRequest) String() string { func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[17] + mi := &file_proxy_service_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1422,7 +1523,7 @@ func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLRequest.ProtoReflect.Descriptor instead. 
func (*GetOIDCURLRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{17} + return file_proxy_service_proto_rawDescGZIP(), []int{18} } func (x *GetOIDCURLRequest) GetId() string { @@ -1457,7 +1558,7 @@ type GetOIDCURLResponse struct { func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1470,7 +1571,7 @@ func (x *GetOIDCURLResponse) String() string { func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1483,7 +1584,7 @@ func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLResponse.ProtoReflect.Descriptor instead. 
func (*GetOIDCURLResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{18} + return file_proxy_service_proto_rawDescGZIP(), []int{19} } func (x *GetOIDCURLResponse) GetUrl() string { @@ -1505,7 +1606,7 @@ type ValidateSessionRequest struct { func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1518,7 +1619,7 @@ func (x *ValidateSessionRequest) String() string { func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1531,7 +1632,7 @@ func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. 
func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{19} + return file_proxy_service_proto_rawDescGZIP(), []int{20} } func (x *ValidateSessionRequest) GetDomain() string { @@ -1562,7 +1663,7 @@ type ValidateSessionResponse struct { func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1575,7 +1676,7 @@ func (x *ValidateSessionResponse) String() string { func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1588,7 +1689,7 @@ func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. 
func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{20} + return file_proxy_service_proto_rawDescGZIP(), []int{21} } func (x *ValidateSessionResponse) GetValid() bool { @@ -1628,124 +1729,147 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x22, 0xa3, 0x01, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, - 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, - 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, - 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 
0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, - 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, - 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, - 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xda, 0x02, - 0x0a, 0x11, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, - 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, - 0x69, 0x70, 0x54, 0x6c, 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x0e, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, - 0x3e, 0x0a, 0x0c, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, - 0x64, 0x65, 0x52, 0x0b, 0x70, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, - 0x57, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 
0x6f, 0x6d, 0x48, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, - 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, - 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, - 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, - 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, - 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xaa, - 0x01, 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, - 0x65, 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, - 0x67, 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, - 0x73, 
0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, - 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0xe0, 0x02, 0x0a, 0x0c, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, - 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, - 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, - 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, - 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, - 0x74, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, - 0x61, 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, - 0x74, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 
0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, - 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, - 0x72, 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, - 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, - 0x77, 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x22, 0x3f, - 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, - 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xea, 0x03, 0x0a, 0x09, 0x41, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, - 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, - 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 
0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, - 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x73, 0x12, 0x16, - 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, - 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x72, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x12, - 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, - 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, - 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, - 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, - 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, - 0x18, 0x0f, 0x20, 
0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, - 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x6f, 0x22, 0x66, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, + 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x15, 0x73, 0x75, 0x70, 0x70, 0x6f, + 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x13, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x6f, 0x72, 0x74, 0x73, 0x88, 0x01, 0x01, + 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x22, 0xe6, 0x01, 0x0a, 0x17, 0x47, + 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, + 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, + 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, + 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, + 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, + 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xce, 0x03, 0x0a, 0x11, 0x50, 0x61, 0x74, + 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x26, + 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, 0x73, + 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, 0x61, + 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 
0x0e, + 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, + 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, 0x70, + 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, + 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, + 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, 0x6f, + 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, + 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, 0x65, + 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, 0x74, + 0x68, 0x4d, 0x61, 0x70, 0x70, 
0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, + 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x61, + 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xaa, 0x01, + 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, + 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, + 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, + 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, + 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, + 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0x95, 0x03, 0x0a, 0x0c, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, 0x74, + 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 
0x79, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, + 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, 0x61, + 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, 0x74, + 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, 0x68, + 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, 0x77, + 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x12, 0x12, 0x0a, + 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6d, 0x6f, 0x64, + 0x65, 
0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, + 0x72, 0x74, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, + 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x86, 0x04, 0x0a, + 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, + 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, + 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, + 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 
0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, + 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, + 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, + 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, + 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, + 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 
0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, @@ -1907,70 +2031,73 @@ func file_proxy_service_proto_rawDescGZIP() []byte { } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 22) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 23) var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (PathRewriteMode)(0), // 1: management.PathRewriteMode (ProxyStatus)(0), // 2: management.ProxyStatus - (*GetMappingUpdateRequest)(nil), // 3: management.GetMappingUpdateRequest - (*GetMappingUpdateResponse)(nil), // 4: management.GetMappingUpdateResponse - (*PathTargetOptions)(nil), // 5: management.PathTargetOptions - (*PathMapping)(nil), // 6: management.PathMapping - (*Authentication)(nil), // 7: management.Authentication - (*ProxyMapping)(nil), // 8: management.ProxyMapping - (*SendAccessLogRequest)(nil), // 9: management.SendAccessLogRequest - (*SendAccessLogResponse)(nil), // 10: management.SendAccessLogResponse - (*AccessLog)(nil), // 11: management.AccessLog - (*AuthenticateRequest)(nil), // 12: management.AuthenticateRequest - (*PasswordRequest)(nil), // 13: management.PasswordRequest - (*PinRequest)(nil), // 14: management.PinRequest - (*AuthenticateResponse)(nil), // 15: management.AuthenticateResponse - (*SendStatusUpdateRequest)(nil), // 16: management.SendStatusUpdateRequest - (*SendStatusUpdateResponse)(nil), // 17: management.SendStatusUpdateResponse - (*CreateProxyPeerRequest)(nil), // 18: management.CreateProxyPeerRequest - (*CreateProxyPeerResponse)(nil), // 19: management.CreateProxyPeerResponse - (*GetOIDCURLRequest)(nil), // 
20: management.GetOIDCURLRequest - (*GetOIDCURLResponse)(nil), // 21: management.GetOIDCURLResponse - (*ValidateSessionRequest)(nil), // 22: management.ValidateSessionRequest - (*ValidateSessionResponse)(nil), // 23: management.ValidateSessionResponse - nil, // 24: management.PathTargetOptions.CustomHeadersEntry - (*timestamppb.Timestamp)(nil), // 25: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 26: google.protobuf.Duration + (*ProxyCapabilities)(nil), // 3: management.ProxyCapabilities + (*GetMappingUpdateRequest)(nil), // 4: management.GetMappingUpdateRequest + (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse + (*PathTargetOptions)(nil), // 6: management.PathTargetOptions + (*PathMapping)(nil), // 7: management.PathMapping + (*Authentication)(nil), // 8: management.Authentication + (*ProxyMapping)(nil), // 9: management.ProxyMapping + (*SendAccessLogRequest)(nil), // 10: management.SendAccessLogRequest + (*SendAccessLogResponse)(nil), // 11: management.SendAccessLogResponse + (*AccessLog)(nil), // 12: management.AccessLog + (*AuthenticateRequest)(nil), // 13: management.AuthenticateRequest + (*PasswordRequest)(nil), // 14: management.PasswordRequest + (*PinRequest)(nil), // 15: management.PinRequest + (*AuthenticateResponse)(nil), // 16: management.AuthenticateResponse + (*SendStatusUpdateRequest)(nil), // 17: management.SendStatusUpdateRequest + (*SendStatusUpdateResponse)(nil), // 18: management.SendStatusUpdateResponse + (*CreateProxyPeerRequest)(nil), // 19: management.CreateProxyPeerRequest + (*CreateProxyPeerResponse)(nil), // 20: management.CreateProxyPeerResponse + (*GetOIDCURLRequest)(nil), // 21: management.GetOIDCURLRequest + (*GetOIDCURLResponse)(nil), // 22: management.GetOIDCURLResponse + (*ValidateSessionRequest)(nil), // 23: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 24: management.ValidateSessionResponse + nil, // 25: management.PathTargetOptions.CustomHeadersEntry + 
(*timestamppb.Timestamp)(nil), // 26: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 27: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 25, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp - 8, // 1: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 26, // 2: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration - 1, // 3: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 24, // 4: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 5, // 5: management.PathMapping.options:type_name -> management.PathTargetOptions - 0, // 6: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType - 6, // 7: management.ProxyMapping.path:type_name -> management.PathMapping - 7, // 8: management.ProxyMapping.auth:type_name -> management.Authentication - 11, // 9: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 25, // 10: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 13, // 11: management.AuthenticateRequest.password:type_name -> management.PasswordRequest - 14, // 12: management.AuthenticateRequest.pin:type_name -> management.PinRequest - 2, // 13: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 3, // 14: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 9, // 15: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 12, // 16: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 16, // 17: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 18, // 18: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 20, // 19: management.ProxyService.GetOIDCURL:input_type -> 
management.GetOIDCURLRequest - 22, // 20: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 4, // 21: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 10, // 22: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 15, // 23: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 17, // 24: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 19, // 25: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 21, // 26: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 23, // 27: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 21, // [21:28] is the sub-list for method output_type - 14, // [14:21] is the sub-list for method input_type - 14, // [14:14] is the sub-list for extension type_name - 14, // [14:14] is the sub-list for extension extendee - 0, // [0:14] is the sub-list for field type_name + 26, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities + 9, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping + 27, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode + 25, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 27, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions + 0, // 8: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType + 7, // 9: 
management.ProxyMapping.path:type_name -> management.PathMapping + 8, // 10: management.ProxyMapping.auth:type_name -> management.Authentication + 12, // 11: management.SendAccessLogRequest.log:type_name -> management.AccessLog + 26, // 12: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 14, // 13: management.AuthenticateRequest.password:type_name -> management.PasswordRequest + 15, // 14: management.AuthenticateRequest.pin:type_name -> management.PinRequest + 2, // 15: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus + 4, // 16: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 10, // 17: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 13, // 18: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 17, // 19: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 19, // 20: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 21, // 21: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 23, // 22: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 5, // 23: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 11, // 24: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 16, // 25: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 18, // 26: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 20, // 27: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 22, // 28: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 24, // 29: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 23, // [23:30] is the 
sub-list for method output_type + 16, // [16:23] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -1980,7 +2107,7 @@ func file_proxy_service_proto_init() { } if !protoimpl.UnsafeEnabled { file_proxy_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateRequest); i { + switch v := v.(*ProxyCapabilities); i { case 0: return &v.state case 1: @@ -1992,7 +2119,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateResponse); i { + switch v := v.(*GetMappingUpdateRequest); i { case 0: return &v.state case 1: @@ -2004,7 +2131,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathTargetOptions); i { + switch v := v.(*GetMappingUpdateResponse); i { case 0: return &v.state case 1: @@ -2016,7 +2143,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathMapping); i { + switch v := v.(*PathTargetOptions); i { case 0: return &v.state case 1: @@ -2028,7 +2155,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Authentication); i { + switch v := v.(*PathMapping); i { case 0: return &v.state case 1: @@ -2040,7 +2167,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyMapping); i { + switch v := v.(*Authentication); i { case 0: return &v.state case 1: @@ -2052,7 +2179,7 @@ func 
file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogRequest); i { + switch v := v.(*ProxyMapping); i { case 0: return &v.state case 1: @@ -2064,7 +2191,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogResponse); i { + switch v := v.(*SendAccessLogRequest); i { case 0: return &v.state case 1: @@ -2076,7 +2203,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AccessLog); i { + switch v := v.(*SendAccessLogResponse); i { case 0: return &v.state case 1: @@ -2088,7 +2215,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateRequest); i { + switch v := v.(*AccessLog); i { case 0: return &v.state case 1: @@ -2100,7 +2227,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PasswordRequest); i { + switch v := v.(*AuthenticateRequest); i { case 0: return &v.state case 1: @@ -2112,7 +2239,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PinRequest); i { + switch v := v.(*PasswordRequest); i { case 0: return &v.state case 1: @@ -2124,7 +2251,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateResponse); i { + switch v := v.(*PinRequest); i { case 0: return &v.state case 1: @@ -2136,7 +2263,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[13].Exporter = func(v interface{}, i int) 
interface{} { - switch v := v.(*SendStatusUpdateRequest); i { + switch v := v.(*AuthenticateResponse); i { case 0: return &v.state case 1: @@ -2148,7 +2275,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateResponse); i { + switch v := v.(*SendStatusUpdateRequest); i { case 0: return &v.state case 1: @@ -2160,7 +2287,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerRequest); i { + switch v := v.(*SendStatusUpdateResponse); i { case 0: return &v.state case 1: @@ -2172,7 +2299,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerResponse); i { + switch v := v.(*CreateProxyPeerRequest); i { case 0: return &v.state case 1: @@ -2184,7 +2311,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLRequest); i { + switch v := v.(*CreateProxyPeerResponse); i { case 0: return &v.state case 1: @@ -2196,7 +2323,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLResponse); i { + switch v := v.(*GetOIDCURLRequest); i { case 0: return &v.state case 1: @@ -2208,7 +2335,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionRequest); i { + switch v := v.(*GetOIDCURLResponse); i { case 0: return &v.state case 1: @@ -2220,6 +2347,18 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { + switch v := 
v.(*ValidateSessionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*ValidateSessionResponse); i { case 0: return &v.state @@ -2232,19 +2371,20 @@ func file_proxy_service_proto_init() { } } } - file_proxy_service_proto_msgTypes[9].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[10].OneofWrappers = []interface{}{ (*AuthenticateRequest_Password)(nil), (*AuthenticateRequest_Pin)(nil), } - file_proxy_service_proto_msgTypes[13].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[16].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[14].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 3, - NumMessages: 22, + NumMessages: 23, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index 195b60f01..457d12e85 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -27,12 +27,19 @@ service ProxyService { rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse); } +// ProxyCapabilities describes what a proxy can handle. +message ProxyCapabilities { + // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. + optional bool supports_custom_ports = 1; +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. 
message GetMappingUpdateRequest { string proxy_id = 1; string version = 2; google.protobuf.Timestamp started_at = 3; string address = 4; + ProxyCapabilities capabilities = 5; } // GetMappingUpdateResponse contains zero or more ProxyMappings. @@ -61,6 +68,10 @@ message PathTargetOptions { google.protobuf.Duration request_timeout = 2; PathRewriteMode path_rewrite = 3; map custom_headers = 4; + // Send PROXY protocol v2 header to this backend. + bool proxy_protocol = 5; + // Idle timeout before a UDP session is reaped. + google.protobuf.Duration session_idle_timeout = 6; } message PathMapping { @@ -91,6 +102,10 @@ message ProxyMapping { // When true, Location headers in backend responses are rewritten to replace // the backend address with the public-facing domain. bool rewrite_redirects = 9; + // Service mode: "http", "tcp", "udp", or "tls". + string mode = 10; + // For L4/TLS: the port the proxy listens on. + int32 listen_port = 11; } // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. 
@@ -117,6 +132,7 @@ message AccessLog { bool auth_success = 13; int64 bytes_upload = 14; int64 bytes_download = 15; + string protocol = 16; } message AuthenticateRequest { From 387e374e4b3250c88c263f994c582872922ceb03 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 16 Mar 2026 22:22:00 +0800 Subject: [PATCH 22/27] [proxy, management] Add header auth, access restrictions, and session idle timeout (#5587) --- .../reverseproxy/accesslogs/accesslogentry.go | 73 +- .../accesslogs/manager/manager.go | 3 + .../reverseproxy/service/manager/manager.go | 60 +- .../modules/reverseproxy/service/service.go | 340 +++- .../reverseproxy/service/service_test.go | 2 +- management/internals/shared/grpc/proxy.go | 69 +- management/server/geolocation/geolocation.go | 6 + proxy/Dockerfile | 2 +- proxy/Dockerfile.multistage | 2 +- proxy/auth/auth.go | 3 +- proxy/cmd/proxy/cmd/root.go | 69 +- proxy/internal/accesslog/logger.go | 147 +- proxy/internal/accesslog/middleware.go | 12 +- proxy/internal/auth/header.go | 69 + proxy/internal/auth/middleware.go | 241 ++- proxy/internal/auth/middleware_test.go | 380 +++- proxy/internal/geolocation/download.go | 264 +++ proxy/internal/geolocation/geolocation.go | 152 ++ proxy/internal/proxy/context.go | 110 +- proxy/internal/proxy/reverseproxy.go | 27 +- proxy/internal/proxy/reverseproxy_test.go | 103 +- proxy/internal/proxy/servicemapping.go | 29 +- proxy/internal/restrict/restrict.go | 183 ++ proxy/internal/restrict/restrict_test.go | 278 +++ proxy/internal/tcp/router.go | 122 +- proxy/internal/tcp/router_test.go | 71 + proxy/internal/udp/relay.go | 100 +- proxy/management_integration_test.go | 3 +- proxy/server.go | 188 +- shared/management/http/api/openapi.yml | 109 +- shared/management/http/api/types.gen.go | 41 +- shared/management/proto/proxy_service.pb.go | 1605 +++++++---------- shared/management/proto/proxy_service.proto | 22 + shared/relay/client/early_msg_buffer.go | 4 +- 34 files changed, 
3509 insertions(+), 1380 deletions(-) create mode 100644 proxy/internal/auth/header.go create mode 100644 proxy/internal/geolocation/download.go create mode 100644 proxy/internal/geolocation/geolocation.go create mode 100644 proxy/internal/restrict/restrict.go create mode 100644 proxy/internal/restrict/restrict_test.go diff --git a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go index 619a34684..a7f692569 100644 --- a/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go +++ b/management/internals/modules/reverseproxy/accesslogs/accesslogentry.go @@ -20,22 +20,23 @@ const ( ) type AccessLogEntry struct { - ID string `gorm:"primaryKey"` - AccountID string `gorm:"index"` - ServiceID string `gorm:"index"` - Timestamp time.Time `gorm:"index"` - GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"` - Method string `gorm:"index"` - Host string `gorm:"index"` - Path string `gorm:"index"` - Duration time.Duration `gorm:"index"` - StatusCode int `gorm:"index"` - Reason string - UserId string `gorm:"index"` - AuthMethodUsed string `gorm:"index"` - BytesUpload int64 `gorm:"index"` - BytesDownload int64 `gorm:"index"` - Protocol AccessLogProtocol `gorm:"index"` + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + ServiceID string `gorm:"index"` + Timestamp time.Time `gorm:"index"` + GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"` + SubdivisionCode string + Method string `gorm:"index"` + Host string `gorm:"index"` + Path string `gorm:"index"` + Duration time.Duration `gorm:"index"` + StatusCode int `gorm:"index"` + Reason string + UserId string `gorm:"index"` + AuthMethodUsed string `gorm:"index"` + BytesUpload int64 `gorm:"index"` + BytesDownload int64 `gorm:"index"` + Protocol AccessLogProtocol `gorm:"index"` } // FromProto creates an AccessLogEntry from a proto.AccessLog @@ -105,6 +106,11 @@ func (a 
*AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { cityName = &a.GeoLocation.CityName } + var subdivisionCode *string + if a.SubdivisionCode != "" { + subdivisionCode = &a.SubdivisionCode + } + var protocol *string if a.Protocol != "" { p := string(a.Protocol) @@ -112,22 +118,23 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog { } return &api.ProxyAccessLog{ - Id: a.ID, - ServiceId: a.ServiceID, - Timestamp: a.Timestamp, - Method: a.Method, - Host: a.Host, - Path: a.Path, - DurationMs: int(a.Duration.Milliseconds()), - StatusCode: a.StatusCode, - SourceIp: sourceIP, - Reason: reason, - UserId: userID, - AuthMethodUsed: authMethod, - CountryCode: countryCode, - CityName: cityName, - BytesUpload: a.BytesUpload, - BytesDownload: a.BytesDownload, - Protocol: protocol, + Id: a.ID, + ServiceId: a.ServiceID, + Timestamp: a.Timestamp, + Method: a.Method, + Host: a.Host, + Path: a.Path, + DurationMs: int(a.Duration.Milliseconds()), + StatusCode: a.StatusCode, + SourceIp: sourceIP, + Reason: reason, + UserId: userID, + AuthMethodUsed: authMethod, + CountryCode: countryCode, + CityName: cityName, + SubdivisionCode: subdivisionCode, + BytesUpload: a.BytesUpload, + BytesDownload: a.BytesDownload, + Protocol: protocol, } } diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go index e7fba7bed..e8d0ce763 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go @@ -41,6 +41,9 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac logEntry.GeoLocation.CountryCode = location.Country.ISOCode logEntry.GeoLocation.CityName = location.City.Names.En logEntry.GeoLocation.GeoNameID = location.City.GeonameID + if len(location.Subdivisions) > 0 { + logEntry.SubdivisionCode = location.Subdivisions[0].ISOCode + } } } diff --git 
a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index c40961fdc..65177bf5d 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand/v2" + "net/http" "os" "slices" "strconv" @@ -229,6 +230,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri return fmt.Errorf("hash secrets: %w", err) } + for i, h := range service.Auth.HeaderAuths { + if h != nil && h.Enabled && h.Value == "" { + return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i) + } + } + keyPair, err := sessionkey.GenerateKeyPair() if err != nil { return fmt.Errorf("generate session keys: %w", err) @@ -488,6 +495,9 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se } m.preserveExistingAuthSecrets(service, existingService) + if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil { + return err + } m.preserveServiceMetadata(service, existingService) m.preserveListenPort(service, existingService) updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled @@ -544,18 +554,52 @@ func isHTTPFamily(mode string) bool { return mode == "" || mode == "http" } -func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) { - if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && +func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) { + if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && - service.Auth.PasswordAuth.Password == "" { - service.Auth.PasswordAuth = existingService.Auth.PasswordAuth + svc.Auth.PasswordAuth.Password == "" { + svc.Auth.PasswordAuth = 
existingService.Auth.PasswordAuth } - if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled && + if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled && existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled && - service.Auth.PinAuth.Pin == "" { - service.Auth.PinAuth = existingService.Auth.PinAuth + svc.Auth.PinAuth.Pin == "" { + svc.Auth.PinAuth = existingService.Auth.PinAuth } + + preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths) +} + +// preserveHeaderAuthHashes fills in empty header auth values from the existing +// service so that unchanged secrets are not lost on update. +func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) { + if len(headers) == 0 || len(existing) == 0 { + return + } + existingByHeader := make(map[string]string, len(existing)) + for _, h := range existing { + if h != nil && h.Value != "" { + existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value + } + } + for _, h := range headers { + if h != nil && h.Enabled && h.Value == "" { + if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok { + h.Value = hash + } + } + } +} + +// validateHeaderAuthValues checks that all enabled header auths have a value +// (either freshly provided or preserved from the existing service). 
+func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error { + for i, h := range headers { + if h != nil && h.Enabled && h.Value == "" { + return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i) + } + } + return nil } func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) { @@ -605,6 +649,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco } return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) } + default: + return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId) } } return nil diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index 623284404..6c7c80806 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -7,14 +7,15 @@ import ( "math/big" "net" "net/http" + "net/netip" "net/url" "regexp" + "slices" "strconv" "strings" "time" "github.com/rs/xid" - log "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" @@ -91,10 +92,37 @@ type BearerAuthConfig struct { DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"` } +// HeaderAuthConfig defines a static header-value auth check. +// The proxy compares the incoming header value against the stored hash. 
+type HeaderAuthConfig struct { + Enabled bool `json:"enabled"` + Header string `json:"header"` + Value string `json:"value"` +} + type AuthConfig struct { PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"` PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"` BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"` + HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"` +} + +// AccessRestrictions controls who can connect to the service based on IP or geography. +type AccessRestrictions struct { + AllowedCIDRs []string `json:"allowed_cidrs,omitempty" gorm:"serializer:json"` + BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"` + AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"` + BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"` +} + +// Copy returns a deep copy of the AccessRestrictions. 
+func (r AccessRestrictions) Copy() AccessRestrictions { + return AccessRestrictions{ + AllowedCIDRs: slices.Clone(r.AllowedCIDRs), + BlockedCIDRs: slices.Clone(r.BlockedCIDRs), + AllowedCountries: slices.Clone(r.AllowedCountries), + BlockedCountries: slices.Clone(r.BlockedCountries), + } } func (a *AuthConfig) HashSecrets() error { @@ -114,6 +142,16 @@ func (a *AuthConfig) HashSecrets() error { a.PinAuth.Pin = hashedPin } + for i, h := range a.HeaderAuths { + if h != nil && h.Enabled && h.Value != "" { + hashedValue, err := argon2id.Hash(h.Value) + if err != nil { + return fmt.Errorf("hash header auth[%d] value: %w", i, err) + } + h.Value = hashedValue + } + } + return nil } @@ -124,6 +162,11 @@ func (a *AuthConfig) ClearSecrets() { if a.PinAuth != nil { a.PinAuth.Pin = "" } + for _, h := range a.HeaderAuths { + if h != nil { + h.Value = "" + } + } } type Meta struct { @@ -143,12 +186,13 @@ type Service struct { Enabled bool PassHostHeader bool RewriteRedirects bool - Auth AuthConfig `gorm:"serializer:json"` - Meta Meta `gorm:"embedded;embeddedPrefix:meta_"` - SessionPrivateKey string `gorm:"column:session_private_key"` - SessionPublicKey string `gorm:"column:session_public_key"` - Source string `gorm:"default:'permanent';index:idx_service_source_peer"` - SourcePeer string `gorm:"index:idx_service_source_peer"` + Auth AuthConfig `gorm:"serializer:json"` + Restrictions AccessRestrictions `gorm:"serializer:json"` + Meta Meta `gorm:"embedded;embeddedPrefix:meta_"` + SessionPrivateKey string `gorm:"column:session_private_key"` + SessionPublicKey string `gorm:"column:session_public_key"` + Source string `gorm:"default:'permanent';index:idx_service_source_peer"` + SourcePeer string `gorm:"index:idx_service_source_peer"` // Mode determines the service type: "http", "tcp", "udp", or "tls". 
Mode string `gorm:"default:'http'"` ListenPort uint16 @@ -188,6 +232,20 @@ func (s *Service) ToAPIResponse() *api.Service { } } + if len(s.Auth.HeaderAuths) > 0 { + apiHeaders := make([]api.HeaderAuthConfig, 0, len(s.Auth.HeaderAuths)) + for _, h := range s.Auth.HeaderAuths { + if h == nil { + continue + } + apiHeaders = append(apiHeaders, api.HeaderAuthConfig{ + Enabled: h.Enabled, + Header: h.Header, + }) + } + authConfig.HeaderAuths = &apiHeaders + } + // Convert internal targets to API targets apiTargets := make([]api.ServiceTarget, 0, len(s.Targets)) for _, target := range s.Targets { @@ -222,18 +280,19 @@ func (s *Service) ToAPIResponse() *api.Service { listenPort := int(s.ListenPort) resp := &api.Service{ - Id: s.ID, - Name: s.Name, - Domain: s.Domain, - Targets: apiTargets, - Enabled: s.Enabled, - PassHostHeader: &s.PassHostHeader, - RewriteRedirects: &s.RewriteRedirects, - Auth: authConfig, - Meta: meta, - Mode: &mode, - ListenPort: &listenPort, - PortAutoAssigned: &s.PortAutoAssigned, + Id: s.ID, + Name: s.Name, + Domain: s.Domain, + Targets: apiTargets, + Enabled: s.Enabled, + PassHostHeader: &s.PassHostHeader, + RewriteRedirects: &s.RewriteRedirects, + Auth: authConfig, + AccessRestrictions: restrictionsToAPI(s.Restrictions), + Meta: meta, + Mode: &mode, + ListenPort: &listenPort, + PortAutoAssigned: &s.PortAutoAssigned, } if s.ProxyCluster != "" { @@ -263,7 +322,16 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf auth.Oidc = true } - return &proto.ProxyMapping{ + for _, h := range s.Auth.HeaderAuths { + if h != nil && h.Enabled { + auth.HeaderAuths = append(auth.HeaderAuths, &proto.HeaderAuth{ + Header: h.Header, + HashedValue: h.Value, + }) + } + } + + mapping := &proto.ProxyMapping{ Type: operationToProtoType(operation), Id: s.ID, Domain: s.Domain, @@ -276,6 +344,12 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf Mode: s.Mode, ListenPort: int32(s.ListenPort), //nolint:gosec } + + 
if r := restrictionsToProto(s.Restrictions); r != nil { + mapping.AccessRestrictions = r + } + + return mapping } // buildPathMappings constructs PathMapping entries from targets. @@ -334,8 +408,7 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { case Delete: return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED default: - log.Fatalf("unknown operation type: %v", op) - return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED + panic(fmt.Sprintf("unknown operation type: %v", op)) } } @@ -477,6 +550,10 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro s.Auth = authFromAPI(req.Auth) } + if req.AccessRestrictions != nil { + s.Restrictions = restrictionsFromAPI(req.AccessRestrictions) + } + return nil } @@ -538,9 +615,70 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig { } auth.BearerAuth = bearerAuth } + if reqAuth.HeaderAuths != nil { + for _, h := range *reqAuth.HeaderAuths { + auth.HeaderAuths = append(auth.HeaderAuths, &HeaderAuthConfig{ + Enabled: h.Enabled, + Header: h.Header, + Value: h.Value, + }) + } + } return auth } +func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions { + if r == nil { + return AccessRestrictions{} + } + var res AccessRestrictions + if r.AllowedCidrs != nil { + res.AllowedCIDRs = *r.AllowedCidrs + } + if r.BlockedCidrs != nil { + res.BlockedCIDRs = *r.BlockedCidrs + } + if r.AllowedCountries != nil { + res.AllowedCountries = *r.AllowedCountries + } + if r.BlockedCountries != nil { + res.BlockedCountries = *r.BlockedCountries + } + return res +} + +func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions { + if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 { + return nil + } + res := &api.AccessRestrictions{} + if len(r.AllowedCIDRs) > 0 { + res.AllowedCidrs = &r.AllowedCIDRs + } + if len(r.BlockedCIDRs) > 0 { + res.BlockedCidrs = &r.BlockedCIDRs + } + if 
len(r.AllowedCountries) > 0 { + res.AllowedCountries = &r.AllowedCountries + } + if len(r.BlockedCountries) > 0 { + res.BlockedCountries = &r.BlockedCountries + } + return res +} + +func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions { + if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 { + return nil + } + return &proto.AccessRestrictions{ + AllowedCidrs: r.AllowedCIDRs, + BlockedCidrs: r.BlockedCIDRs, + AllowedCountries: r.AllowedCountries, + BlockedCountries: r.BlockedCountries, + } +} + func (s *Service) Validate() error { if s.Name == "" { return errors.New("service name is required") @@ -557,6 +695,13 @@ func (s *Service) Validate() error { s.Mode = ModeHTTP } + if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil { + return err + } + if err := validateAccessRestrictions(&s.Restrictions); err != nil { + return err + } + switch s.Mode { case ModeHTTP: return s.validateHTTPMode() @@ -657,6 +802,21 @@ func (s *Service) validateL4Target(target *Target) error { if target.Path != nil && *target.Path != "" && *target.Path != "/" { return errors.New("path is not supported for L4 services") } + if target.Options.SessionIdleTimeout < 0 { + return errors.New("session_idle_timeout must be positive for L4 services") + } + if target.Options.RequestTimeout < 0 { + return errors.New("request_timeout must be positive for L4 services") + } + if target.Options.SkipTLSVerify { + return errors.New("skip_tls_verify is not supported for L4 services") + } + if target.Options.PathRewrite != "" { + return errors.New("path_rewrite is not supported for L4 services") + } + if len(target.Options.CustomHeaders) > 0 { + return errors.New("custom_headers is not supported for L4 services") + } return nil } @@ -688,11 +848,9 @@ func IsPortBasedProtocol(mode string) bool { } const ( - maxRequestTimeout = 5 * time.Minute - maxSessionIdleTimeout = 10 * time.Minute - maxCustomHeaders = 16 - 
maxHeaderKeyLen = 128 - maxHeaderValueLen = 4096 + maxCustomHeaders = 16 + maxHeaderKeyLen = 128 + maxHeaderValueLen = 4096 ) // httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition. @@ -731,22 +889,12 @@ func validateTargetOptions(idx int, opts *TargetOptions) error { return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite) } - if opts.RequestTimeout != 0 { - if opts.RequestTimeout <= 0 { - return fmt.Errorf("target %d: request_timeout must be positive", idx) - } - if opts.RequestTimeout > maxRequestTimeout { - return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout) - } + if opts.RequestTimeout < 0 { + return fmt.Errorf("target %d: request_timeout must be positive", idx) } - if opts.SessionIdleTimeout != 0 { - if opts.SessionIdleTimeout <= 0 { - return fmt.Errorf("target %d: session_idle_timeout must be positive", idx) - } - if opts.SessionIdleTimeout > maxSessionIdleTimeout { - return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout) - } + if opts.SessionIdleTimeout < 0 { + return fmt.Errorf("target %d: session_idle_timeout must be positive", idx) } if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil { @@ -796,6 +944,93 @@ func containsCRLF(s string) bool { return strings.ContainsAny(s, "\r\n") } +func validateHeaderAuths(headers []*HeaderAuthConfig) error { + seen := make(map[string]struct{}) + for i, h := range headers { + if h == nil || !h.Enabled { + continue + } + if h.Header == "" { + return fmt.Errorf("header_auths[%d]: header name is required", i) + } + if !httpHeaderNameRe.MatchString(h.Header) { + return fmt.Errorf("header_auths[%d]: header name %q is not a valid HTTP header name", i, h.Header) + } + canonical := http.CanonicalHeaderKey(h.Header) + if _, ok := hopByHopHeaders[canonical]; ok { + return fmt.Errorf("header_auths[%d]: header %q is a hop-by-hop header and cannot be used for auth", i, 
h.Header) + } + if _, ok := reservedHeaders[canonical]; ok { + return fmt.Errorf("header_auths[%d]: header %q is managed by the proxy and cannot be used for auth", i, h.Header) + } + if canonical == "Host" { + return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i) + } + if _, dup := seen[canonical]; dup { + return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header) + } + seen[canonical] = struct{}{} + if len(h.Value) > maxHeaderValueLen { + return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen) + } + } + return nil +} + +const ( + maxCIDREntries = 200 + maxCountryEntries = 50 +) + +// validateAccessRestrictions validates and normalizes access restriction +// entries. Country codes are uppercased in place. +func validateAccessRestrictions(r *AccessRestrictions) error { + if len(r.AllowedCIDRs) > maxCIDREntries { + return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries) + } + if len(r.BlockedCIDRs) > maxCIDREntries { + return fmt.Errorf("blocked_cidrs: exceeds maximum of %d entries", maxCIDREntries) + } + if len(r.AllowedCountries) > maxCountryEntries { + return fmt.Errorf("allowed_countries: exceeds maximum of %d entries", maxCountryEntries) + } + if len(r.BlockedCountries) > maxCountryEntries { + return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries) + } + + for i, raw := range r.AllowedCIDRs { + prefix, err := netip.ParsePrefix(raw) + if err != nil { + return fmt.Errorf("allowed_cidrs[%d]: %w", i, err) + } + if prefix != prefix.Masked() { + return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked()) + } + } + for i, raw := range r.BlockedCIDRs { + prefix, err := netip.ParsePrefix(raw) + if err != nil { + return fmt.Errorf("blocked_cidrs[%d]: %w", i, err) + } + if prefix != prefix.Masked() { + return fmt.Errorf("blocked_cidrs[%d]: %q has host bits 
set, use %s instead", i, raw, prefix.Masked()) + } + } + for i, code := range r.AllowedCountries { + if len(code) != 2 { + return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code) + } + r.AllowedCountries[i] = strings.ToUpper(code) + } + for i, code := range r.BlockedCountries { + if len(code) != 2 { + return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code) + } + r.BlockedCountries[i] = strings.ToUpper(code) + } + return nil +} + func (s *Service) EventMeta() map[string]any { meta := map[string]any{ "name": s.Name, @@ -827,9 +1062,17 @@ func (s *Service) EventMeta() map[string]any { } func (s *Service) isAuthEnabled() bool { - return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) || + if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) || (s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) || - (s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) + (s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) { + return true + } + for _, h := range s.Auth.HeaderAuths { + if h != nil && h.Enabled { + return true + } + } + return false } func (s *Service) Copy() *Service { @@ -866,6 +1109,16 @@ func (s *Service) Copy() *Service { } authCopy.BearerAuth = &ba } + if len(s.Auth.HeaderAuths) > 0 { + authCopy.HeaderAuths = make([]*HeaderAuthConfig, len(s.Auth.HeaderAuths)) + for i, h := range s.Auth.HeaderAuths { + if h == nil { + continue + } + hCopy := *h + authCopy.HeaderAuths[i] = &hCopy + } + } return &Service{ ID: s.ID, @@ -878,6 +1131,7 @@ func (s *Service) Copy() *Service { PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, Auth: authCopy, + Restrictions: s.Restrictions.Copy(), Meta: s.Meta, SessionPrivateKey: s.SessionPrivateKey, SessionPublicKey: s.SessionPublicKey, diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index a8a8ae5d6..9daf729fe 
100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -120,9 +120,9 @@ func TestValidateTargetOptions_RequestTimeout(t *testing.T) { }{ {"valid 30s", 30 * time.Second, ""}, {"valid 2m", 2 * time.Minute, ""}, + {"valid 10m", 10 * time.Minute, ""}, {"zero is fine", 0, ""}, {"negative", -1 * time.Second, "must be positive"}, - {"exceeds max", 10 * time.Minute, "exceeds maximum"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 31a0ba0db..fd993fb40 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net/http" "net/url" "strings" "sync" @@ -493,16 +494,17 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo // should be set on the copy. 
func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { return &proto.ProxyMapping{ - Type: m.Type, - Id: m.Id, - AccountId: m.AccountId, - Domain: m.Domain, - Path: m.Path, - Auth: m.Auth, - PassHostHeader: m.PassHostHeader, - RewriteRedirects: m.RewriteRedirects, - Mode: m.Mode, - ListenPort: m.ListenPort, + Type: m.Type, + Id: m.Id, + AccountId: m.AccountId, + Domain: m.Domain, + Path: m.Path, + Auth: m.Auth, + PassHostHeader: m.PassHostHeader, + RewriteRedirects: m.RewriteRedirects, + Mode: m.Mode, + ListenPort: m.ListenPort, + AccessRestrictions: m.AccessRestrictions, } } @@ -561,6 +563,8 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth) case *proto.AuthenticateRequest_Password: return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth) + case *proto.AuthenticateRequest_HeaderAuth: + return s.authenticateHeader(ctx, req.GetId(), v, service.Auth.HeaderAuths) default: return false, "", "" } @@ -594,6 +598,35 @@ func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID return true, "password-user", proxyauth.MethodPassword } +func (s *ProxyServiceServer) authenticateHeader(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_HeaderAuth, auths []*rpservice.HeaderAuthConfig) (bool, string, proxyauth.Method) { + if len(auths) == 0 { + log.WithContext(ctx).Debugf("header authentication attempted but no header auths configured for service %s", serviceID) + return false, "", "" + } + + headerName := http.CanonicalHeaderKey(req.HeaderAuth.GetHeaderName()) + + var lastErr error + for _, auth := range auths { + if auth == nil || !auth.Enabled { + continue + } + if headerName != "" && http.CanonicalHeaderKey(auth.Header) != headerName { + continue + } + if err := argon2id.Verify(req.HeaderAuth.GetHeaderValue(), auth.Value); err != nil { + lastErr = err + continue + } + return true, "header-user", 
proxyauth.MethodHeader + } + + if lastErr != nil { + s.logAuthenticationError(ctx, lastErr, "Header") + } + return false, "", "" +} + func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) { if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) { log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType) @@ -752,6 +785,9 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU if err != nil { return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) } + if redirectURL.Scheme != "https" && redirectURL.Scheme != "http" { + return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme") + } // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId()) if err != nil { @@ -836,12 +872,9 @@ func (s *ProxyServiceServer) generateHMAC(input string) string { // ValidateState validates the state parameter from an OAuth callback. // Returns the original redirect URL if valid, or an error if invalid. +// The HMAC is verified before consuming the PKCE verifier to prevent +// an attacker from invalidating a legitimate user's auth flow. func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) { - verifier, ok := s.pkceVerifierStore.LoadAndDelete(state) - if !ok { - return "", "", errors.New("no verifier for state") - } - // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) parts := strings.Split(state, "|") if len(parts) != 3 { @@ -865,6 +898,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL return "", "", errors.New("invalid state signature") } + // Consume the PKCE verifier only after HMAC validation passes. 
+ verifier, ok := s.pkceVerifierStore.LoadAndDelete(state) + if !ok { + return "", "", errors.New("no verifier for state") + } + return verifier, redirectURL, nil } diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index c0179a1c4..30fd493e8 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -44,6 +44,12 @@ type Record struct { GeonameID uint `maxminddb:"geoname_id"` ISOCode string `maxminddb:"iso_code"` } `maxminddb:"country"` + Subdivisions []struct { + ISOCode string `maxminddb:"iso_code"` + Names struct { + En string `maxminddb:"en"` + } `maxminddb:"names"` + } `maxminddb:"subdivisions"` } type City struct { diff --git a/proxy/Dockerfile b/proxy/Dockerfile index 096c71f21..e64680fd6 100644 --- a/proxy/Dockerfile +++ b/proxy/Dockerfile @@ -10,7 +10,7 @@ FROM gcr.io/distroless/base:debug COPY netbird-proxy /go/bin/netbird-proxy COPY --from=builder /tmp/passwd /etc/passwd COPY --from=builder /tmp/group /etc/group -COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird +COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs USER netbird:netbird ENV HOME=/var/lib/netbird diff --git a/proxy/Dockerfile.multistage b/proxy/Dockerfile.multistage index 2e3ac3561..01e342c0e 100644 --- a/proxy/Dockerfile.multistage +++ b/proxy/Dockerfile.multistage @@ -28,7 +28,7 @@ FROM gcr.io/distroless/base:debug COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy COPY --from=builder /tmp/passwd /etc/passwd COPY --from=builder /tmp/group /etc/group -COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird +COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs USER netbird:netbird ENV HOME=/var/lib/netbird diff --git a/proxy/auth/auth.go b/proxy/auth/auth.go index 14caa03b3..ca9c260b7 
100644 --- a/proxy/auth/auth.go +++ b/proxy/auth/auth.go @@ -13,10 +13,11 @@ import ( type Method string -var ( +const ( MethodPassword Method = "password" MethodPIN Method = "pin" MethodOIDC Method = "oidc" + MethodHeader Method = "header" ) func (m Method) String() string { diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index d82f5b7fc..a2252cc20 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -36,31 +36,33 @@ var ( var ( logLevel string - debugLogs bool - mgmtAddr string - addr string - proxyDomain string - defaultDialTimeout time.Duration - certDir string - acmeCerts bool - acmeAddr string - acmeDir string - acmeEABKID string - acmeEABHMACKey string - acmeChallengeType string - debugEndpoint bool - debugEndpointAddr string - healthAddr string - forwardedProto string - trustedProxies string - certFile string - certKeyFile string - certLockMethod string - wildcardCertDir string - wgPort uint16 - proxyProtocol bool - preSharedKey string - supportsCustomPorts bool + debugLogs bool + mgmtAddr string + addr string + proxyDomain string + maxDialTimeout time.Duration + maxSessionIdleTimeout time.Duration + certDir string + acmeCerts bool + acmeAddr string + acmeDir string + acmeEABKID string + acmeEABHMACKey string + acmeChallengeType string + debugEndpoint bool + debugEndpointAddr string + healthAddr string + forwardedProto string + trustedProxies string + certFile string + certKeyFile string + certLockMethod string + wildcardCertDir string + wgPort uint16 + proxyProtocol bool + preSharedKey string + supportsCustomPorts bool + geoDataDir string ) var rootCmd = &cobra.Command{ @@ -99,7 +101,9 @@ func init() { rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a 
pre-shared key for the tunnel between proxy and peers") rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") - rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)") + rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)") + rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)") + rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)") } // Execute runs the root command. 
@@ -177,17 +181,15 @@ func runServer(cmd *cobra.Command, args []string) error { ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, SupportsCustomPorts: supportsCustomPorts, - DefaultDialTimeout: defaultDialTimeout, + MaxDialTimeout: maxDialTimeout, + MaxSessionIdleTimeout: maxSessionIdleTimeout, + GeoDataDir: geoDataDir, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) defer stop() - if err := srv.ListenAndServe(ctx, addr); err != nil { - logger.Error(err) - return err - } - return nil + return srv.ListenAndServe(ctx, addr) } func envBoolOrDefault(key string, def bool) bool { @@ -197,6 +199,7 @@ func envBoolOrDefault(key string, def bool) bool { } parsed, err := strconv.ParseBool(v) if err != nil { + log.Warnf("parse %s=%q: %v, using default %v", key, v, err, def) return def } return parsed @@ -217,6 +220,7 @@ func envUint16OrDefault(key string, def uint16) uint16 { } parsed, err := strconv.ParseUint(v, 10, 16) if err != nil { + log.Warnf("parse %s=%q: %v, using default %d", key, v, err, def) return def } return uint16(parsed) @@ -229,6 +233,7 @@ func envDurationOrDefault(key string, def time.Duration) time.Duration { } parsed, err := time.ParseDuration(v) if err != nil { + log.Warnf("parse %s=%q: %v, using default %s", key, v, err, def) return def } return parsed diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index 5b05ab195..3ed3275b5 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -4,6 +4,7 @@ import ( "context" "net/netip" "sync" + "sync/atomic" "time" "github.com/rs/xid" @@ -22,6 +23,16 @@ const ( usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours logSendTimeout = 10 * time.Second + + // denyCooldown is the min interval between deny log entries per service+reason + // to prevent flooding from denied connections (e.g. 
UDP packets from blocked IPs). + denyCooldown = 10 * time.Second + + // maxDenyBuckets caps tracked deny rate-limit entries to bound memory under DDoS. + maxDenyBuckets = 10000 + + // maxLogWorkers caps concurrent gRPC send goroutines. + maxLogWorkers = 4096 ) type domainUsage struct { @@ -38,6 +49,18 @@ type gRPCClient interface { SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error) } +// denyBucketKey identifies a rate-limited deny log stream. +type denyBucketKey struct { + ServiceID types.ServiceID + Reason string +} + +// denyBucket tracks rate-limited deny log entries. +type denyBucket struct { + lastLogged time.Time + suppressed int64 +} + // Logger sends access log entries to the management server via gRPC. type Logger struct { client gRPCClient @@ -47,7 +70,12 @@ type Logger struct { usageMux sync.Mutex domainUsage map[string]*domainUsage + denyMu sync.Mutex + denyBuckets map[denyBucketKey]*denyBucket + + logSem chan struct{} cleanupCancel context.CancelFunc + dropped atomic.Int64 } // NewLogger creates a new access log Logger. 
The trustedProxies parameter @@ -64,6 +92,8 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre logger: logger, trustedProxies: trustedProxies, domainUsage: make(map[string]*domainUsage), + denyBuckets: make(map[denyBucketKey]*denyBucket), + logSem: make(chan struct{}, maxLogWorkers), cleanupCancel: cancel, } @@ -83,7 +113,7 @@ func (l *Logger) Close() { type logEntry struct { ID string AccountID types.AccountID - ServiceId types.ServiceID + ServiceID types.ServiceID Host string Path string DurationMs int64 @@ -91,7 +121,7 @@ type logEntry struct { ResponseCode int32 SourceIP netip.Addr AuthMechanism string - UserId string + UserID string AuthSuccess bool BytesUpload int64 BytesDownload int64 @@ -118,6 +148,10 @@ type L4Entry struct { DurationMs int64 BytesUpload int64 BytesDownload int64 + // DenyReason, when non-empty, indicates the connection was denied. + // Values match the HTTP auth mechanism strings: "ip_restricted", + // "country_restricted", "geo_unavailable". + DenyReason string } // LogL4 sends an access log entry for a layer-4 connection (TCP or UDP). @@ -126,7 +160,7 @@ func (l *Logger) LogL4(entry L4Entry) { le := logEntry{ ID: xid.New().String(), AccountID: entry.AccountID, - ServiceId: entry.ServiceID, + ServiceID: entry.ServiceID, Protocol: entry.Protocol, Host: entry.Host, SourceIP: entry.SourceIP, @@ -134,10 +168,47 @@ func (l *Logger) LogL4(entry L4Entry) { BytesUpload: entry.BytesUpload, BytesDownload: entry.BytesDownload, } + if entry.DenyReason != "" { + if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) { + return + } + le.AuthMechanism = entry.DenyReason + le.AuthSuccess = false + } l.log(le) l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload) } +// allowDenyLog rate-limits deny log entries per service+reason combination. 
+func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool { + key := denyBucketKey{ServiceID: serviceID, Reason: reason} + now := time.Now() + + l.denyMu.Lock() + defer l.denyMu.Unlock() + + b, ok := l.denyBuckets[key] + if !ok { + if len(l.denyBuckets) >= maxDenyBuckets { + return false + } + l.denyBuckets[key] = &denyBucket{lastLogged: now} + return true + } + + if now.Sub(b.lastLogged) >= denyCooldown { + if b.suppressed > 0 { + l.logger.Debugf("access restriction: suppressed %d deny log entries for %s (%s)", b.suppressed, serviceID, reason) + } + b.lastLogged = now + b.suppressed = 0 + return true + } + + b.suppressed++ + return false +} + func (l *Logger) log(entry logEntry) { // Fire off the log request in a separate routine. // This increases the possibility of losing a log message @@ -147,12 +218,21 @@ func (l *Logger) log(entry logEntry) { // There is also a chance that log messages will arrive at // the server out of order; however, the timestamp should // allow for resolving that on the server. - now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary. + now := timestamppb.Now() + select { + case l.logSem <- struct{}{}: + default: + total := l.dropped.Add(1) + l.logger.Debugf("access log send dropped: worker limit reached (total dropped: %d)", total) + return + } go func() { + defer func() { <-l.logSem }() logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout) defer cancel() + // Only OIDC sessions have a meaningful user identity. 
if entry.AuthMechanism != auth.MethodOIDC.String() { - entry.UserId = "" + entry.UserID = "" } var sourceIP string @@ -165,7 +245,7 @@ func (l *Logger) log(entry logEntry) { LogId: entry.ID, AccountId: string(entry.AccountID), Timestamp: now, - ServiceId: string(entry.ServiceId), + ServiceId: string(entry.ServiceID), Host: entry.Host, Path: entry.Path, DurationMs: entry.DurationMs, @@ -173,7 +253,7 @@ func (l *Logger) log(entry logEntry) { ResponseCode: entry.ResponseCode, SourceIp: sourceIP, AuthMechanism: entry.AuthMechanism, - UserId: entry.UserId, + UserId: entry.UserID, AuthSuccess: entry.AuthSuccess, BytesUpload: entry.BytesUpload, BytesDownload: entry.BytesDownload, @@ -181,7 +261,7 @@ func (l *Logger) log(entry logEntry) { }, }); err != nil { l.logger.WithFields(log.Fields{ - "service_id": entry.ServiceId, + "service_id": entry.ServiceID, "host": entry.Host, "path": entry.Path, "duration": entry.DurationMs, @@ -189,7 +269,7 @@ func (l *Logger) log(entry logEntry) { "response_code": entry.ResponseCode, "source_ip": sourceIP, "auth_mechanism": entry.AuthMechanism, - "user_id": entry.UserId, + "user_id": entry.UserID, "auth_success": entry.AuthSuccess, "error": err, }).Error("Error sending access log on gRPC connection") @@ -248,7 +328,7 @@ func (l *Logger) trackUsage(domain string, bytesTransferred int64) { } } -// cleanupStaleUsage removes usage entries for domains that have been inactive. +// cleanupStaleUsage removes usage and deny-rate-limit entries that have been inactive. 
func (l *Logger) cleanupStaleUsage(ctx context.Context) { ticker := time.NewTicker(usageCleanupPeriod) defer ticker.Stop() @@ -258,20 +338,41 @@ func (l *Logger) cleanupStaleUsage(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - l.usageMux.Lock() now := time.Now() - removed := 0 - for domain, usage := range l.domainUsage { - if now.Sub(usage.lastActivity) > usageInactiveWindow { - delete(l.domainUsage, domain) - removed++ - } - } - l.usageMux.Unlock() - - if removed > 0 { - l.logger.Debugf("cleaned up %d stale domain usage entries", removed) - } + l.cleanupDomainUsage(now) + l.cleanupDenyBuckets(now) } } } + +func (l *Logger) cleanupDomainUsage(now time.Time) { + l.usageMux.Lock() + defer l.usageMux.Unlock() + + removed := 0 + for domain, usage := range l.domainUsage { + if now.Sub(usage.lastActivity) > usageInactiveWindow { + delete(l.domainUsage, domain) + removed++ + } + } + if removed > 0 { + l.logger.Debugf("cleaned up %d stale domain usage entries", removed) + } +} + +func (l *Logger) cleanupDenyBuckets(now time.Time) { + l.denyMu.Lock() + defer l.denyMu.Unlock() + + removed := 0 + for key, bucket := range l.denyBuckets { + if now.Sub(bucket.lastLogged) > usageInactiveWindow { + delete(l.denyBuckets, key) + removed++ + } + } + if removed > 0 { + l.logger.Debugf("cleaned up %d stale deny rate-limit entries", removed) + } +} diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 593a77ef2..81c790b17 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/proxy/web" ) +// Middleware wraps an HTTP handler to log access entries and resolve client IPs. func (l *Logger) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip logging for internal proxy assets (CSS, JS, etc.) 
@@ -47,8 +48,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { // Create a mutable struct to capture data from downstream handlers. // We pass a pointer in the context - the pointer itself flows down immutably, // but the struct it points to can be mutated by inner handlers. - capturedData := &proxy.CapturedData{RequestID: requestID} + capturedData := proxy.NewCapturedData(requestID) capturedData.SetClientIP(sourceIp) + ctx := proxy.WithCapturedData(r.Context(), capturedData) start := time.Now() @@ -66,8 +68,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { entry := logEntry{ ID: requestID, - ServiceId: capturedData.GetServiceId(), - AccountID: capturedData.GetAccountId(), + ServiceID: capturedData.GetServiceID(), + AccountID: capturedData.GetAccountID(), Host: host, Path: r.URL.Path, DurationMs: duration.Milliseconds(), @@ -75,14 +77,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { ResponseCode: int32(sw.status), SourceIP: sourceIp, AuthMechanism: capturedData.GetAuthMethod(), - UserId: capturedData.GetUserID(), + UserID: capturedData.GetUserID(), AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, BytesUpload: bytesUpload, BytesDownload: bytesDownload, Protocol: ProtocolHTTP, } l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", - requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) + requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID()) l.log(entry) diff --git a/proxy/internal/auth/header.go b/proxy/internal/auth/header.go new file mode 100644 index 000000000..194800a49 --- /dev/null +++ b/proxy/internal/auth/header.go @@ -0,0 +1,69 @@ +package auth + +import ( + 
"errors" + "fmt" + "net/http" + + "github.com/netbirdio/netbird/proxy/auth" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// ErrHeaderAuthFailed indicates that the header was present but the +// credential did not validate. Callers should return 401 instead of +// falling through to other auth schemes. +var ErrHeaderAuthFailed = errors.New("header authentication failed") + +// Header implements header-based authentication. The proxy checks for the +// configured header in each request and validates its value via gRPC. +type Header struct { + id types.ServiceID + accountId types.AccountID + headerName string + client authenticator +} + +// NewHeader creates a Header authentication scheme for the given header name. +func NewHeader(client authenticator, id types.ServiceID, accountId types.AccountID, headerName string) Header { + return Header{ + id: id, + accountId: accountId, + headerName: headerName, + client: client, + } +} + +// Type returns auth.MethodHeader. +func (Header) Type() auth.Method { + return auth.MethodHeader +} + +// Authenticate checks for the configured header in the request. If absent, +// returns empty (unauthenticated). If present, validates via gRPC. 
+func (h Header) Authenticate(r *http.Request) (string, string, error) { + value := r.Header.Get(h.headerName) + if value == "" { + return "", "", nil + } + + res, err := h.client.Authenticate(r.Context(), &proto.AuthenticateRequest{ + Id: string(h.id), + AccountId: string(h.accountId), + Request: &proto.AuthenticateRequest_HeaderAuth{ + HeaderAuth: &proto.HeaderAuthRequest{ + HeaderValue: value, + HeaderName: h.headerName, + }, + }, + }) + if err != nil { + return "", "", fmt.Errorf("authenticate header: %w", err) + } + + if res.GetSuccess() { + return res.GetSessionToken(), "", nil + } + + return "", "", ErrHeaderAuthFailed +} diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 3cf86e4b3..670cafb68 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -4,9 +4,12 @@ import ( "context" "crypto/ed25519" "encoding/base64" + "errors" "fmt" + "html" "net" "net/http" + "net/netip" "net/url" "sync" "time" @@ -16,11 +19,16 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/shared/management/proto" ) +// errValidationUnavailable indicates that session validation failed due to +// an infrastructure error (e.g. gRPC unavailable), not an invalid token. +var errValidationUnavailable = errors.New("session validation unavailable") + type authenticator interface { Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error) } @@ -40,12 +48,14 @@ type Scheme interface { Authenticate(*http.Request) (token string, promptData string, err error) } +// DomainConfig holds the authentication and restriction settings for a protected domain. 
type DomainConfig struct { Schemes []Scheme SessionPublicKey ed25519.PublicKey SessionExpiration time.Duration AccountID types.AccountID ServiceID types.ServiceID + IPRestrictions *restrict.Filter } type validationResult struct { @@ -54,17 +64,18 @@ type validationResult struct { DeniedReason string } +// Middleware applies per-domain authentication and IP restriction checks. type Middleware struct { domainsMux sync.RWMutex domains map[string]DomainConfig logger *log.Logger sessionValidator SessionValidator + geo restrict.GeoResolver } -// NewMiddleware creates a new authentication middleware. -// The sessionValidator is optional; if nil, OIDC session tokens will be validated -// locally without group access checks. -func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware { +// NewMiddleware creates a new authentication middleware. The sessionValidator is +// optional; if nil, OIDC session tokens are validated locally without group access checks. +func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware { if logger == nil { logger = log.StandardLogger() } @@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl domains: make(map[string]DomainConfig), logger: logger, sessionValidator: sessionValidator, + geo: geo, } } -// Protect applies authentication middleware to the passed handler. -// For each incoming request it will be checked against the middleware's -// internal list of protected domains. -// If the Host domain in the inbound request is not present, then it will -// simply be passed through. -// However, if the Host domain is present, then the specified authentication -// schemes for that domain will be applied to the request. -// In the event that no authentication schemes are defined for the domain, -// then the request will also be simply passed through. 
+// Protect wraps next with per-domain authentication and IP restriction checks. +// Requests whose Host is not registered pass through unchanged. func (mw *Middleware) Protect(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { host, _, err := net.SplitHostPort(r.Host) @@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { config, exists := mw.getDomainConfig(host) mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists) - // Domains that are not configured here or have no authentication schemes applied should simply pass through. - if !exists || len(config.Schemes) == 0 { + if !exists { next.ServeHTTP(w, r) return } @@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { // Set account and service IDs in captured data for access logging. setCapturedIDs(r, config) + if !mw.checkIPRestrictions(w, r, config) { + return + } + + // Domains with no authentication schemes pass through after IP checks. + if len(config.Schemes) == 0 { + next.ServeHTTP(w, r) + return + } + if mw.handleOAuthCallbackError(w, r) { return } @@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { return } + if mw.forwardWithHeaderAuth(w, r, host, config, next) { + return + } + mw.authenticateWithSchemes(w, r, host, config) }) } @@ -124,11 +142,65 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) { func setCapturedIDs(r *http.Request, config DomainConfig) { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetAccountId(config.AccountID) - cd.SetServiceId(config.ServiceID) + cd.SetAccountID(config.AccountID) + cd.SetServiceID(config.ServiceID) } } +// checkIPRestrictions validates the client IP against the domain's IP restrictions. +// Uses the resolved client IP from CapturedData (which accounts for trusted proxies) +// rather than r.RemoteAddr directly. 
+func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool { + if config.IPRestrictions == nil { + return true + } + + clientIP := mw.resolveClientIP(r) + if !clientIP.IsValid() { + mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr) + http.Error(w, "Forbidden", http.StatusForbidden) + return false + } + + verdict := config.IPRestrictions.Check(clientIP, mw.geo) + if verdict == restrict.Allow { + return true + } + + reason := verdict.String() + mw.blockIPRestriction(r, reason) + http.Error(w, "Forbidden", http.StatusForbidden) + return false +} + +// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr. +func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr { + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + if ip := cd.GetClientIP(); ip.IsValid() { + return ip + } + } + + clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr) + if clientIPStr == "" { + clientIPStr = r.RemoteAddr + } + addr, err := netip.ParseAddr(clientIPStr) + if err != nil { + return netip.Addr{} + } + return addr.Unmap() +} + +// blockIPRestriction sets captured data fields for an IP-restriction block event. +func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) { + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetAuthMethod(reason) + } + mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr) +} + // handleOAuthCallbackError checks for error query parameters from an OAuth // callback and renders the access denied page if present. 
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool { @@ -146,6 +218,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re errDesc := r.URL.Query().Get("error_description") if errDesc == "" { errDesc = "An error occurred during authentication" + } else { + errDesc = html.EscapeString(errDesc) } web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID) return true @@ -170,6 +244,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re return true } +// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates, +// the request is forwarded directly (no redirect), which is important for API clients. +func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool { + for _, scheme := range config.Schemes { + hdr, ok := scheme.(Header) + if !ok { + continue + } + + handled := mw.tryHeaderScheme(w, r, host, config, hdr, next) + if handled { + return true + } + } + return false +} + +func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool { + token, _, err := hdr.Authenticate(r) + if err != nil { + return mw.handleHeaderAuthError(w, r, err) + } + if token == "" { + return false + } + + result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader) + if err != nil { + setHeaderCapturedData(r.Context(), "") + status := http.StatusBadRequest + msg := "invalid session token" + if errors.Is(err, errValidationUnavailable) { + status = http.StatusBadGateway + msg = "authentication service unavailable" + } + http.Error(w, msg, status) + return true + } + + if !result.Valid { + setHeaderCapturedData(r.Context(), result.UserID) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return true + } + + 
setSessionCookie(w, token, config.SessionExpiration) + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetUserID(result.UserID) + cd.SetAuthMethod(auth.MethodHeader.String()) + } + + next.ServeHTTP(w, r) + return true +} + +func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool { + if errors.Is(err, ErrHeaderAuthFailed) { + setHeaderCapturedData(r.Context(), "") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return true + } + mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err) + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + } + http.Error(w, "authentication service unavailable", http.StatusBadGateway) + return true +} + +func setHeaderCapturedData(ctx context.Context, userID string) { + cd := proxy.CapturedDataFromContext(ctx) + if cd == nil { + return + } + cd.SetOrigin(proxy.OriginAuth) + cd.SetAuthMethod(auth.MethodHeader.String()) + cd.SetUserID(userID) +} + // authenticateWithSchemes tries each configured auth scheme in order. // On success it sets a session cookie and redirects; on failure it renders the login page. 
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) { @@ -217,7 +370,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re cd.SetOrigin(proxy.OriginAuth) cd.SetAuthMethod(scheme.Type().String()) } - http.Error(w, err.Error(), http.StatusBadRequest) + status := http.StatusBadRequest + msg := "invalid session token" + if errors.Is(err, errValidationUnavailable) { + status = http.StatusBadGateway + msg = "authentication service unavailable" + } + http.Error(w, msg, status) return } @@ -233,7 +392,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re return } - expiration := config.SessionExpiration + setSessionCookie(w, token, config.SessionExpiration) + + // Redirect instead of forwarding the auth POST to the backend. + // The browser will follow with a GET carrying the new session cookie. + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetUserID(result.UserID) + cd.SetAuthMethod(scheme.Type().String()) + } + redirectURL := stripSessionTokenParam(r.URL) + http.Redirect(w, r, redirectURL, http.StatusSeeOther) +} + +// setSessionCookie writes a session cookie with secure defaults. +func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) { if expiration == 0 { expiration = auth.DefaultSessionExpiry } @@ -245,16 +418,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re SameSite: http.SameSiteLaxMode, MaxAge: int(expiration.Seconds()), }) - - // Redirect instead of forwarding the auth POST to the backend. - // The browser will follow with a GET carrying the new session cookie. 
- if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(proxy.OriginAuth) - cd.SetUserID(result.UserID) - cd.SetAuthMethod(scheme.Type().String()) - } - redirectURL := stripSessionTokenParam(r.URL) - http.Redirect(w, r, redirectURL, http.StatusSeeOther) } // wasCredentialSubmitted checks if credentials were submitted for the given auth method. @@ -275,13 +438,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool { // session JWTs. Returns an error if the key is missing or invalid. // Callers must not serve the domain if this returns an error, to avoid // exposing an unauthenticated service. -func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error { +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error { if len(schemes) == 0 { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() mw.domains[domain] = DomainConfig{ - AccountID: accountID, - ServiceID: serviceID, + AccountID: accountID, + ServiceID: serviceID, + IPRestrictions: ipRestrictions, } return nil } @@ -302,30 +466,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st SessionExpiration: expiration, AccountID: accountID, ServiceID: serviceID, + IPRestrictions: ipRestrictions, } return nil } +// RemoveDomain unregisters authentication for the given domain. func (mw *Middleware) RemoveDomain(domain string) { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() delete(mw.domains, domain) } -// validateSessionToken validates a session token, optionally checking group access via gRPC. -// For OIDC tokens with a configured validator, it calls ValidateSession to check group access. -// For other auth methods (PIN, password), it validates the JWT locally. 
-// Returns a validationResult with user ID and validity status, or error for invalid tokens. +// validateSessionToken validates a session token. OIDC tokens with a configured +// validator go through gRPC for group access checks; other methods validate locally. func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) { - // For OIDC with a session validator, call the gRPC service to check group access if method == auth.MethodOIDC && mw.sessionValidator != nil { resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{ Domain: host, SessionToken: token, }) if err != nil { - mw.logger.WithError(err).Error("ValidateSession gRPC call failed") - return nil, fmt.Errorf("session validation failed") + return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err) } if !resp.Valid { mw.logger.WithFields(log.Fields{ @@ -342,7 +504,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri return &validationResult{UserID: resp.UserId, Valid: true}, nil } - // For non-OIDC methods or when no validator is configured, validate JWT locally userID, _, err := auth.ValidateSessionJWT(token, host, publicKey) if err != nil { return nil, err diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 7d9ac1bd5..a4924d380 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -1,11 +1,14 @@ package auth import ( + "context" "crypto/ed25519" "crypto/rand" "encoding/base64" + "errors" "net/http" "net/http/httptest" + "net/netip" "net/url" "strings" "testing" @@ -14,10 +17,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/proxy/auth" 
"github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/restrict" + "github.com/netbirdio/netbird/shared/management/proto" ) func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair { @@ -52,11 +58,11 @@ func newPassthroughHandler() http.Handler { } func TestAddDomain_ValidKey(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil) require.NoError(t, err) mw.domainsMux.RLock() @@ -70,10 +76,10 @@ func TestAddDomain_ValidKey(t *testing.T) { } func TestAddDomain_EmptyKey(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil) require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -84,10 +90,10 @@ func TestAddDomain_EmptyKey(t *testing.T) { } func TestAddDomain_InvalidBase64(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil) require.Error(t, err) assert.Contains(t, err.Error(), "decode session public key") @@ -98,11 +104,11 @@ func TestAddDomain_InvalidBase64(t *testing.T) { } func TestAddDomain_WrongKeySize(t *testing.T) { - mw := 
NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort")) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil) require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -113,9 +119,9 @@ func TestAddDomain_WrongKeySize(t *testing.T) { } func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) - err := mw.AddDomain("example.com", nil, "", time.Hour, "", "") + err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil) require.NoError(t, err, "domains with no auth schemes should not require a key") mw.domainsMux.RLock() @@ -125,14 +131,14 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { } func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp1 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil)) mw.domainsMux.RLock() config := mw.domains["example.com"] @@ -144,11 +150,11 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { } func TestRemoveDomain(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := 
NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) mw.RemoveDomain("example.com") @@ -159,7 +165,7 @@ func TestRemoveDomain(t *testing.T) { } func TestProtect_UnknownDomainPassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) handler := mw.Protect(newPassthroughHandler()) req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil) @@ -171,8 +177,8 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) { } func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) - require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "")) + mw := NewMiddleware(log.StandardLogger(), nil, nil) + require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)) handler := mw.Protect(newPassthroughHandler()) @@ -185,11 +191,11 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { } func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -206,11 +212,11 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { } func TestProtect_HostWithPortIsMatched(t *testing.T) { - mw := 
NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -227,16 +233,16 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) { } func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) require.NoError(t, err) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cd := proxy.CapturedDataFromContext(r.Context()) require.NotNil(t, cd) @@ -257,11 +263,11 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { } func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) // Sign a token that expired 1 
second ago. token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second) @@ -283,11 +289,11 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { } func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) // Token signed for a different domain audience. token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour) @@ -309,12 +315,12 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { } func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp1 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil)) // Token signed with a different private key. 
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) @@ -336,7 +342,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { } func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour) @@ -351,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -386,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { } func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{ @@ -395,7 +401,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) handler := mw.Protect(newPassthroughHandler()) @@ -409,7 +415,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { } func TestProtect_MultipleSchemes(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour) @@ -431,7 +437,7 @@ func TestProtect_MultipleSchemes(t 
*testing.T) { return "", "password", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -451,7 +457,7 @@ func TestProtect_MultipleSchemes(t *testing.T) { } func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) // Return a garbage token that won't validate. @@ -461,7 +467,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { return "invalid-jwt-token", "", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) handler := mw.Protect(newPassthroughHandler()) @@ -473,7 +479,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { } func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) // 32 random bytes that happen to be valid base64 and correct size // but are actually a valid ed25519 public key length-wise. 
@@ -485,19 +491,19 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { key := base64.StdEncoding.EncodeToString(randomBytes) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "") + err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil) require.NoError(t, err, "any 32-byte key should be accepted at registration time") } func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) // Attempt to overwrite with an invalid key. - err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "") + err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil) require.Error(t, err) // The original valid config should still be intact. 
@@ -511,7 +517,7 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { } func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) // Scheme that always fails authentication (returns empty token) @@ -521,9 +527,9 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) // Submit wrong PIN - should capture auth method @@ -539,7 +545,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { } func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{ @@ -548,9 +554,9 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { return "", "password", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) // Submit wrong password - should capture auth method @@ -566,7 +572,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { } func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) { - mw := NewMiddleware(log.StandardLogger(), nil) + mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := 
&stubScheme{ @@ -575,9 +581,9 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) - capturedData := &proxy.CapturedData{} + capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) // No credentials submitted - should not capture auth method @@ -658,3 +664,271 @@ func TestWasCredentialSubmitted(t *testing.T) { }) } } + +func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + remoteAddr string + wantCode int + }{ + {"unparsable address denies", "not-an-ip:1234", http.StatusForbidden}, + {"empty address denies", "", http.StatusForbidden}, + {"allowed address passes", "10.1.2.3:5678", http.StatusOK}, + {"denied address blocked", "192.168.1.1:5678", http.StatusForbidden}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = tt.remoteAddr + req.Host = "example.com" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, tt.wantCode, rr.Code) + }) + } +} + +func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) { + // When CapturedData is set (by the access log middleware, which resolves + // trusted proxies), checkIPRestrictions should use that IP, not RemoteAddr. 
+ mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil)) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // RemoteAddr is a trusted proxy, but CapturedData has the real client IP. + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = "10.0.0.1:5000" + req.Host = "example.com" + + cd := proxy.NewCapturedData("") + cd.SetClientIP(netip.MustParseAddr("203.0.113.50")) + ctx := proxy.WithCapturedData(req.Context(), cd) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "should use CapturedData IP (203.0.113.50), not RemoteAddr (10.0.0.1)") + + // Same request but CapturedData has a blocked IP. + req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req2.RemoteAddr = "203.0.113.50:5000" + req2.Host = "example.com" + + cd2 := proxy.NewCapturedData("") + cd2.SetClientIP(netip.MustParseAddr("10.0.0.1")) + ctx2 := proxy.WithCapturedData(req2.Context(), cd2) + req2 = req2.WithContext(ctx2) + + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + assert.Equal(t, http.StatusForbidden, rr2.Code, "should use CapturedData IP (10.0.0.1), not RemoteAddr (203.0.113.50)") +} + +func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) { + // Geo is nil, country restrictions are configured: must deny (fail-close). 
+ mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter(nil, nil, []string{"US"}, nil)) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = "1.2.3.4:5678" + req.Host = "example.com" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny") +} + +// mockAuthenticator is a minimal mock for the authenticator gRPC interface +// used by the Header scheme. +type mockAuthenticator struct { + fn func(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) +} + +func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, _ ...grpc.CallOption) (*proto.AuthenticateResponse, error) { + return m.fn(ctx, in) +} + +// newHeaderSchemeWithToken creates a Header scheme backed by a mock that +// returns a signed session token when the expected header value is provided. 
+func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header { + t.Helper() + token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour) + require.NoError(t, err) + + mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + ha := req.GetHeaderAuth() + if ha != nil && ha.GetHeaderValue() == expectedValue { + return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil + } + return &proto.AuthenticateResponse{Success: false}, nil + }} + return NewHeader(mock, "svc1", "acc1", headerName) +} + +func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + var backendCalled bool + capturedData := proxy.NewCapturedData("") + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil) + req.Header.Set("X-API-Key", "secret-key") + req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.True(t, backendCalled, "backend should be called directly for header auth (no redirect)") + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "ok", rec.Body.String()) + + // Session cookie should be set. 
+ var sessionCookie *http.Cookie + for _, c := range rec.Result().Cookies() { + if c.Name == auth.SessionCookieName { + sessionCookie = c + break + } + } + require.NotNil(t, sessionCookie, "session cookie should be set after successful header auth") + assert.True(t, sessionCookie.HttpOnly) + assert.True(t, sessionCookie.Secure) + + assert.Equal(t, "header-user", capturedData.GetUserID()) + assert.Equal(t, "header", capturedData.GetAuthMethod()) +} + +func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") + // Also add a PIN scheme so we can verify fallthrough behavior. + pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + // No X-API-Key header: should fall through to PIN login page (401). 
+ req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, "missing header should fall through to login page") +} + +func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + return &proto.AuthenticateResponse{Success: false}, nil + }} + hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + capturedData := proxy.NewCapturedData("") + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("X-API-Key", "wrong-key") + req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + assert.Equal(t, "header", capturedData.GetAuthMethod()) +} + +func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + return nil, errors.New("gRPC unavailable") + }} + hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("X-API-Key", "some-key") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadGateway, 
rec.Code) +} + +func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request with header auth. + req1 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req1.Header.Set("X-API-Key", "secret-key") + req1 = req1.WithContext(proxy.WithCapturedData(req1.Context(), proxy.NewCapturedData(""))) + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + + // Extract session cookie. + var sessionCookie *http.Cookie + for _, c := range rec1.Result().Cookies() { + if c.Name == auth.SessionCookieName { + sessionCookie = c + break + } + } + require.NotNil(t, sessionCookie) + + // Second request with only the session cookie (no header). 
+ capturedData2 := proxy.NewCapturedData("") + req2 := httptest.NewRequest(http.MethodGet, "http://example.com/other", nil) + req2.AddCookie(sessionCookie) + req2 = req2.WithContext(proxy.WithCapturedData(req2.Context(), capturedData2)) + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + + assert.Equal(t, http.StatusOK, rec2.Code) + assert.Equal(t, "header-user", capturedData2.GetUserID()) + assert.Equal(t, "header", capturedData2.GetAuthMethod()) +} diff --git a/proxy/internal/geolocation/download.go b/proxy/internal/geolocation/download.go new file mode 100644 index 000000000..64d515275 --- /dev/null +++ b/proxy/internal/geolocation/download.go @@ -0,0 +1,264 @@ +package geolocation + +import ( + "archive/tar" + "bufio" + "compress/gzip" + "crypto/sha256" + "errors" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz" + mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256" + mmdbInnerName = "GeoLite2-City.mmdb" + + downloadTimeout = 2 * time.Minute + maxMMDBSize = 256 << 20 // 256 MB +) + +// ensureMMDB checks for an existing MMDB file in dataDir. If none is found, +// it downloads from pkgs.netbird.io with SHA256 verification. 
+func ensureMMDB(logger *log.Logger, dataDir string) (string, error) { + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err) + } + + pattern := filepath.Join(dataDir, mmdbGlob) + if files, _ := filepath.Glob(pattern); len(files) > 0 { + mmdbPath := files[len(files)-1] + logger.Debugf("using existing geolocation database: %s", mmdbPath) + return mmdbPath, nil + } + + logger.Info("geolocation database not found, downloading from pkgs.netbird.io") + return downloadMMDB(logger, dataDir) +} + +func downloadMMDB(logger *log.Logger, dataDir string) (string, error) { + client := &http.Client{Timeout: downloadTimeout} + + datedName, err := fetchRemoteFilename(client, mmdbTarGZURL) + if err != nil { + return "", fmt.Errorf("get remote filename: %w", err) + } + + mmdbFilename := deriveMMDBFilename(datedName) + mmdbPath := filepath.Join(dataDir, mmdbFilename) + + tmp, err := os.MkdirTemp("", "geolite-proxy-*") + if err != nil { + return "", fmt.Errorf("create temp directory: %w", err) + } + defer os.RemoveAll(tmp) + + checksumFile := filepath.Join(tmp, "checksum.sha256") + if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil { + return "", fmt.Errorf("download checksum: %w", err) + } + + expectedHash, err := readChecksumFile(checksumFile) + if err != nil { + return "", fmt.Errorf("read checksum: %w", err) + } + + tarFile := filepath.Join(tmp, datedName) + logger.Debugf("downloading geolocation database (%s)", datedName) + if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil { + return "", fmt.Errorf("download database: %w", err) + } + + if err := verifySHA256(tarFile, expectedHash); err != nil { + return "", fmt.Errorf("verify database checksum: %w", err) + } + + if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil { + return "", fmt.Errorf("extract database: %w", err) + } + + logger.Infof("geolocation database downloaded: %s", mmdbPath) + return mmdbPath, nil 
+} + +// deriveMMDBFilename converts a tar.gz filename to an MMDB filename. +// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb +func deriveMMDBFilename(tarName string) string { + base, _, _ := strings.Cut(tarName, ".") + if !strings.Contains(base, "_") { + return "GeoLite2-City.mmdb" + } + return base + ".mmdb" +} + +func fetchRemoteFilename(client *http.Client, url string) (string, error) { + resp, err := client.Head(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode) + } + + cd := resp.Header.Get("Content-Disposition") + if cd == "" { + return "", errors.New("no Content-Disposition header") + } + + _, params, err := mime.ParseMediaType(cd) + if err != nil { + return "", fmt.Errorf("parse Content-Disposition: %w", err) + } + + name := filepath.Base(params["filename"]) + if name == "" || name == "." { + return "", errors.New("no filename in Content-Disposition") + } + return name, nil +} + +func downloadToFile(client *http.Client, url, dest string) error { + resp, err := client.Get(url) //nolint:gosec + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + f, err := os.Create(dest) //nolint:gosec + if err != nil { + return err + } + defer f.Close() + + // Cap download at 256 MB to prevent unbounded reads from a compromised server. 
+ if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil { + return err + } + return nil +} + +func readChecksumFile(path string) (string, error) { + f, err := os.Open(path) //nolint:gosec + if err != nil { + return "", err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + if scanner.Scan() { + parts := strings.Fields(scanner.Text()) + if len(parts) > 0 { + return parts[0], nil + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return "", errors.New("empty checksum file") +} + +func verifySHA256(path, expected string) error { + f, err := os.Open(path) //nolint:gosec + if err != nil { + return err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return err + } + + actual := fmt.Sprintf("%x", h.Sum(nil)) + if actual != expected { + return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual) + } + return nil +} + +func extractMMDBFromTarGZ(tarGZPath, destPath string) error { + f, err := os.Open(tarGZPath) //nolint:gosec + if err != nil { + return err + } + defer f.Close() + + gz, err := gzip.NewReader(f) + if err != nil { + return err + } + defer gz.Close() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + + if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName { + if hdr.Size < 0 || hdr.Size > maxMMDBSize { + return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize) + } + if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil { + return err + } + return nil + } + } + + return fmt.Errorf("%s not found in archive", mmdbInnerName) +} + +// extractToFileAtomic writes r to a temporary file in the same directory as +// destPath, then renames it into place so a crash never leaves a truncated file. 
+func extractToFileAtomic(r io.Reader, destPath string) error { + dir := filepath.Dir(destPath) + tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp") + if err != nil { + return fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmp.Name() + + if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader + if closeErr := tmp.Close(); closeErr != nil { + log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr) + } + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr) + } + return fmt.Errorf("write mmdb: %w", err) + } + if err := tmp.Close(); err != nil { + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr) + } + return fmt.Errorf("close temp file: %w", err) + } + if err := os.Rename(tmpPath, destPath); err != nil { + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr) + } + return fmt.Errorf("rename to %s: %w", destPath, err) + } + return nil +} diff --git a/proxy/internal/geolocation/geolocation.go b/proxy/internal/geolocation/geolocation.go new file mode 100644 index 000000000..81b02efb3 --- /dev/null +++ b/proxy/internal/geolocation/geolocation.go @@ -0,0 +1,152 @@ +// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases. +package geolocation + +import ( + "fmt" + "net/netip" + "os" + "strconv" + "sync" + + "github.com/oschwald/maxminddb-golang" + log "github.com/sirupsen/logrus" +) + +const ( + // EnvDisable disables geolocation lookups entirely when set to a truthy value. 
+ EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION" + + mmdbGlob = "GeoLite2-City_*.mmdb" +) + +type record struct { + Country struct { + ISOCode string `maxminddb:"iso_code"` + } `maxminddb:"country"` + City struct { + Names struct { + En string `maxminddb:"en"` + } `maxminddb:"names"` + } `maxminddb:"city"` + Subdivisions []struct { + ISOCode string `maxminddb:"iso_code"` + Names struct { + En string `maxminddb:"en"` + } `maxminddb:"names"` + } `maxminddb:"subdivisions"` +} + +// Result holds the outcome of a geo lookup. +type Result struct { + CountryCode string + CityName string + SubdivisionCode string + SubdivisionName string +} + +// Lookup provides IP geolocation lookups. +type Lookup struct { + mu sync.RWMutex + db *maxminddb.Reader + logger *log.Logger +} + +// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir. +// Returns nil without error if geolocation is disabled via environment +// variable, no data directory is configured, or the download fails +// (graceful degradation: country restrictions will deny all requests). +func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) { + if isDisabledByEnv(logger) { + logger.Info("geolocation disabled via environment variable") + return nil, nil //nolint:nilnil + } + + if dataDir == "" { + return nil, nil //nolint:nilnil + } + + mmdbPath, err := ensureMMDB(logger, dataDir) + if err != nil { + logger.Warnf("geolocation database unavailable: %v", err) + logger.Warn("country-based access restrictions will deny all requests until a database is available") + return nil, nil //nolint:nilnil + } + + db, err := maxminddb.Open(mmdbPath) + if err != nil { + return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err) + } + + logger.Infof("geolocation database loaded from %s", mmdbPath) + return &Lookup{db: db, logger: logger}, nil +} + +// LookupAddr returns the country ISO code and city name for the given IP. +// Returns an empty Result if the database is nil or the lookup fails. 
+func (l *Lookup) LookupAddr(addr netip.Addr) Result { + if l == nil { + return Result{} + } + + l.mu.RLock() + defer l.mu.RUnlock() + + if l.db == nil { + return Result{} + } + + addr = addr.Unmap() + + var rec record + if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil { + l.logger.Debugf("geolocation lookup %s: %v", addr, err) + return Result{} + } + r := Result{ + CountryCode: rec.Country.ISOCode, + CityName: rec.City.Names.En, + } + if len(rec.Subdivisions) > 0 { + r.SubdivisionCode = rec.Subdivisions[0].ISOCode + r.SubdivisionName = rec.Subdivisions[0].Names.En + } + return r +} + +// Available reports whether the lookup has a loaded database. +func (l *Lookup) Available() bool { + if l == nil { + return false + } + l.mu.RLock() + defer l.mu.RUnlock() + return l.db != nil +} + +// Close releases the database resources. +func (l *Lookup) Close() error { + if l == nil { + return nil + } + l.mu.Lock() + defer l.mu.Unlock() + if l.db != nil { + err := l.db.Close() + l.db = nil + return err + } + return nil +} + +func isDisabledByEnv(logger *log.Logger) bool { + val := os.Getenv(EnvDisable) + if val == "" { + return false + } + disabled, err := strconv.ParseBool(val) + if err != nil { + logger.Warnf("parse %s=%q: %v", EnvDisable, val, err) + return false + } + return disabled +} diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index 4a61f6bcf..d3f67dc57 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -11,8 +11,6 @@ import ( type requestContextKey string const ( - serviceIdKey requestContextKey = "serviceId" - accountIdKey requestContextKey = "accountId" capturedDataKey requestContextKey = "capturedData" ) @@ -47,112 +45,117 @@ func (o ResponseOrigin) String() string { // to pass data back up the middleware chain. 
type CapturedData struct { mu sync.RWMutex - RequestID string - ServiceId types.ServiceID - AccountId types.AccountID - Origin ResponseOrigin - ClientIP netip.Addr - UserID string - AuthMethod string + requestID string + serviceID types.ServiceID + accountID types.AccountID + origin ResponseOrigin + clientIP netip.Addr + userID string + authMethod string } -// GetRequestID safely gets the request ID +// NewCapturedData creates a CapturedData with the given request ID. +func NewCapturedData(requestID string) *CapturedData { + return &CapturedData{requestID: requestID} +} + +// GetRequestID returns the request ID. func (c *CapturedData) GetRequestID() string { c.mu.RLock() defer c.mu.RUnlock() - return c.RequestID + return c.requestID } -// SetServiceId safely sets the service ID -func (c *CapturedData) SetServiceId(serviceId types.ServiceID) { +// SetServiceID sets the service ID. +func (c *CapturedData) SetServiceID(serviceID types.ServiceID) { c.mu.Lock() defer c.mu.Unlock() - c.ServiceId = serviceId + c.serviceID = serviceID } -// GetServiceId safely gets the service ID -func (c *CapturedData) GetServiceId() types.ServiceID { +// GetServiceID returns the service ID. +func (c *CapturedData) GetServiceID() types.ServiceID { c.mu.RLock() defer c.mu.RUnlock() - return c.ServiceId + return c.serviceID } -// SetAccountId safely sets the account ID -func (c *CapturedData) SetAccountId(accountId types.AccountID) { +// SetAccountID sets the account ID. +func (c *CapturedData) SetAccountID(accountID types.AccountID) { c.mu.Lock() defer c.mu.Unlock() - c.AccountId = accountId + c.accountID = accountID } -// GetAccountId safely gets the account ID -func (c *CapturedData) GetAccountId() types.AccountID { +// GetAccountID returns the account ID. +func (c *CapturedData) GetAccountID() types.AccountID { c.mu.RLock() defer c.mu.RUnlock() - return c.AccountId + return c.accountID } -// SetOrigin safely sets the response origin +// SetOrigin sets the response origin. 
func (c *CapturedData) SetOrigin(origin ResponseOrigin) { c.mu.Lock() defer c.mu.Unlock() - c.Origin = origin + c.origin = origin } -// GetOrigin safely gets the response origin +// GetOrigin returns the response origin. func (c *CapturedData) GetOrigin() ResponseOrigin { c.mu.RLock() defer c.mu.RUnlock() - return c.Origin + return c.origin } -// SetClientIP safely sets the resolved client IP. +// SetClientIP sets the resolved client IP. func (c *CapturedData) SetClientIP(ip netip.Addr) { c.mu.Lock() defer c.mu.Unlock() - c.ClientIP = ip + c.clientIP = ip } -// GetClientIP safely gets the resolved client IP. +// GetClientIP returns the resolved client IP. func (c *CapturedData) GetClientIP() netip.Addr { c.mu.RLock() defer c.mu.RUnlock() - return c.ClientIP + return c.clientIP } -// SetUserID safely sets the authenticated user ID. +// SetUserID sets the authenticated user ID. func (c *CapturedData) SetUserID(userID string) { c.mu.Lock() defer c.mu.Unlock() - c.UserID = userID + c.userID = userID } -// GetUserID safely gets the authenticated user ID. +// GetUserID returns the authenticated user ID. func (c *CapturedData) GetUserID() string { c.mu.RLock() defer c.mu.RUnlock() - return c.UserID + return c.userID } -// SetAuthMethod safely sets the authentication method used. +// SetAuthMethod sets the authentication method used. func (c *CapturedData) SetAuthMethod(method string) { c.mu.Lock() defer c.mu.Unlock() - c.AuthMethod = method + c.authMethod = method } -// GetAuthMethod safely gets the authentication method used. +// GetAuthMethod returns the authentication method used. func (c *CapturedData) GetAuthMethod() string { c.mu.RLock() defer c.mu.RUnlock() - return c.AuthMethod + return c.authMethod } -// WithCapturedData adds a CapturedData struct to the context +// WithCapturedData adds a CapturedData struct to the context. 
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { return context.WithValue(ctx, capturedDataKey, data) } -// CapturedDataFromContext retrieves the CapturedData from context +// CapturedDataFromContext retrieves the CapturedData from context. func CapturedDataFromContext(ctx context.Context) *CapturedData { v := ctx.Value(capturedDataKey) data, ok := v.(*CapturedData) @@ -161,28 +164,3 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData { } return data } - -func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context { - return context.WithValue(ctx, serviceIdKey, serviceId) -} - -func ServiceIdFromContext(ctx context.Context) types.ServiceID { - v := ctx.Value(serviceIdKey) - serviceId, ok := v.(types.ServiceID) - if !ok { - return "" - } - return serviceId -} -func withAccountId(ctx context.Context, accountId types.AccountID) context.Context { - return context.WithValue(ctx, accountIdKey, accountId) -} - -func AccountIdFromContext(ctx context.Context) types.AccountID { - v := ctx.Value(accountIdKey) - accountId, ok := v.(types.AccountID) - if !ok { - return "" - } - return accountId -} diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index 1ee9b2a42..246851d24 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -66,19 +66,16 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Set the serviceId in the context for later retrieval. - ctx := withServiceId(r.Context(), result.serviceID) - // Set the accountId in the context for later retrieval (for middleware). - ctx = withAccountId(ctx, result.accountID) - // Set the accountId in the context for the roundtripper to use. + ctx := r.Context() + // Set the account ID in the context for the roundtripper to use. 
ctx = roundtrip.WithAccountID(ctx, result.accountID) - // Also populate captured data if it exists (allows middleware to read after handler completes). + // Populate captured data if it exists (allows middleware to read after handler completes). // This solves the problem of passing data UP the middleware chain: we put a mutable struct // pointer in the context, and mutate the struct here so outer middleware can read it. if capturedData := CapturedDataFromContext(ctx); capturedData != nil { - capturedData.SetServiceId(result.serviceID) - capturedData.SetAccountId(result.accountID) + capturedData.SetServiceID(result.serviceID) + capturedData.SetAccountID(result.accountID) } pt := result.target @@ -96,10 +93,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { } rp := &httputil.ReverseProxy{ - Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders), + Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders), Transport: p.transport, FlushInterval: -1, - ErrorHandler: proxyErrorHandler, + ErrorHandler: p.proxyErrorHandler, } if result.rewriteRedirects { rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose @@ -113,7 +110,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // When passHostHeader is true, the original client Host header is preserved // instead of being rewritten to the backend's address. // The pathRewrite parameter controls how the request path is transformed. 
-func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) { +func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string, stripAuthHeaders []string) func(r *httputil.ProxyRequest) { return func(r *httputil.ProxyRequest) { switch pathRewrite { case PathRewritePreserve: @@ -137,6 +134,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost r.Out.Host = target.Host } + for _, h := range stripAuthHeaders { + r.Out.Header.Del(h) + } + for k, v := range customHeaders { r.Out.Header.Set(k, v) } @@ -305,7 +306,7 @@ func extractForwardedPort(host, resolvedProto string) string { // proxyErrorHandler handles errors from the reverse proxy and serves // user-friendly error pages instead of raw error responses. -func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { +func (p *ReverseProxy) proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { if cd := CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(OriginProxyError) } @@ -313,7 +314,7 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { clientIP := getClientIP(r) title, message, code, status := classifyProxyError(err) - log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v", + p.logger.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v", requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err) web.ServeErrorPage(w, r, code, title, message, requestID, status) diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index b05ead198..c53307837 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -28,7 
+28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} t.Run("rewrites host to backend by default", func(t *testing.T) { - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") rewrite(pr) @@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) { }) t.Run("preserves original host when passHostHeader is true", func(t *testing.T) { - rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") rewrite(pr) @@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) { func TestRewriteFunc_XForwardedForStripping(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) { pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") @@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000") rewrite(pr) @@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, 
PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000") rewrite(pr) @@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr.In.TLS = &tls.ConnectionState{} @@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") rewrite(pr) @@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("auto detects https from TLS", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr.In.TLS = &tls.ConnectionState{} @@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("auto detects http without TLS", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") rewrite(pr) @@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("forced proto overrides TLS 
detection", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "https"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") // No TLS, but forced to https @@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { t.Run("forced http proto", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "http"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr.In.TLS = &tls.ConnectionState{} @@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) { func TestRewriteFunc_SessionCookieStripping(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) t.Run("strips nb_session cookie", func(t *testing.T) { pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") @@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) { func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") p := &ReverseProxy{forwardedProto: "auto"} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) t.Run("strips session_token query parameter", func(t *testing.T) { pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000") @@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { t.Run("rewrites URL to target with path prefix", func(t *testing.T) { target, _ := 
url.Parse("http://backend.internal:8080/app") - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000") rewrite(pr) @@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) { target, _ := url.Parse("https://backend.example.org:443/app") - rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000") rewrite(pr) @@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) { t.Run("strips matched prefix and preserves subpath", func(t *testing.T) { target, _ := url.Parse("https://backend.example.org:443/app") - rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000") rewrite(pr) @@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("appends to X-Forwarded-For", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") @@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Real-IP", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := 
newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") @@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2") @@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-Host", "original.example.com") @@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-Proto", "https") @@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") 
pr.In.Header.Set("X-Forwarded-Port", "8443") @@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") @@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") @@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1") @@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("empty trusted list behaves as untrusted", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") @@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) { t.Run("XFF starts fresh when 
trusted proxy has no upstream XFF", func(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} - rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") @@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) { // Management builds: path="/heise", target="https://heise.de:443/heise" target, _ := url.Parse("https://heise.de:443/heise") - rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") rewrite(pr) @@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { t.Run("subpath under prefix also preserved", func(t *testing.T) { target, _ := url.Parse("https://heise.de:443/heise") - rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") rewrite(pr) @@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { // What the behavior WOULD be if target URL had no path (true stripping) t.Run("target without path prefix gives true stripping", func(t *testing.T) { target, _ := url.Parse("https://heise.de:443") - rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") rewrite(pr) @@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) { target, _ := 
url.Parse("https://heise.de:443") - rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") rewrite(pr) @@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) { // Root path "/" — no stripping expected t.Run("root path forwards full request path unchanged", func(t *testing.T) { target, _ := url.Parse("https://backend.example.com:443/") - rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") rewrite(pr) @@ -551,7 +551,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) { target, _ := url.Parse("http://backend.internal:8080") t.Run("preserve keeps full request path", func(t *testing.T) { - rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil) + rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil, nil) pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000") rewrite(pr) @@ -561,7 +561,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) { }) t.Run("preserve with root matchedPath", func(t *testing.T) { - rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil) + rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil, nil) pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000") rewrite(pr) @@ -579,7 +579,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) { "X-Custom-Auth": "token-abc", "X-Env": "production", } - rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers) + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") rewrite(pr) @@ -589,7 +589,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) 
{ }) t.Run("nil customHeaders is fine", func(t *testing.T) { - rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil) + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") rewrite(pr) @@ -599,7 +599,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) { t.Run("custom headers override existing request headers", func(t *testing.T) { headers := map[string]string{"X-Override": "new-value"} - rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers) + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil) pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr.In.Header.Set("X-Override", "old-value") @@ -609,11 +609,38 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) { }) } +func TestRewriteFunc_StripsAuthorizationHeader(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto"} + target, _ := url.Parse("http://backend.internal:8080") + + t.Run("strips incoming Authorization when no custom Authorization set", func(t *testing.T) { + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, []string{"Authorization"}) + pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") + pr.In.Header.Set("Authorization", "Bearer proxy-token") + + rewrite(pr) + + assert.Empty(t, pr.Out.Header.Get("Authorization"), "Authorization should be stripped") + }) + + t.Run("custom Authorization replaces incoming", func(t *testing.T) { + headers := map[string]string{"Authorization": "Basic YmFja2VuZDpzZWNyZXQ="} + rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, []string{"Authorization"}) + pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") + pr.In.Header.Set("Authorization", "Bearer proxy-token") + + rewrite(pr) + + assert.Equal(t, "Basic YmFja2VuZDpzZWNyZXQ=", pr.Out.Header.Get("Authorization"), + "backend Authorization from custom headers should be set") + }) +} + func 
TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) { p := &ReverseProxy{forwardedProto: "auto"} target, _ := url.Parse("http://backend.internal:8080") - rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}) + rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}, nil) pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000") rewrite(pr) diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index 1513fbe45..fe470cf01 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -38,6 +38,11 @@ type Mapping struct { Paths map[string]*PathTarget PassHostHeader bool RewriteRedirects bool + // StripAuthHeaders are header names used for header-based auth. + // These headers are stripped from requests before forwarding. + StripAuthHeaders []string + // sortedPaths caches the paths sorted by length (longest first). + sortedPaths []string } type targetResult struct { @@ -47,6 +52,7 @@ type targetResult struct { accountID types.AccountID passHostHeader bool rewriteRedirects bool + stripAuthHeaders []string } func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) { @@ -65,16 +71,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo return targetResult{}, false } - // Sort paths by length (longest first) in a naive attempt to match the most specific route first. 
- paths := make([]string, 0, len(m.Paths)) - for path := range m.Paths { - paths = append(paths, path) - } - sort.Slice(paths, func(i, j int) bool { - return len(paths[i]) > len(paths[j]) - }) - - for _, path := range paths { + for _, path := range m.sortedPaths { if strings.HasPrefix(req.URL.Path, path) { pt := m.Paths[path] if pt == nil || pt.URL == nil { @@ -89,6 +86,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo accountID: m.AccountID, passHostHeader: m.PassHostHeader, rewriteRedirects: m.RewriteRedirects, + stripAuthHeaders: m.StripAuthHeaders, }, true } } @@ -96,7 +94,18 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo return targetResult{}, false } +// AddMapping registers a host-to-backend mapping for the reverse proxy. func (p *ReverseProxy) AddMapping(m Mapping) { + // Sort paths longest-first to match the most specific route first. + paths := make([]string, 0, len(m.Paths)) + for path := range m.Paths { + paths = append(paths, path) + } + sort.Slice(paths, func(i, j int) bool { + return len(paths[i]) > len(paths[j]) + }) + m.sortedPaths = paths + p.mappingsMux.Lock() defer p.mappingsMux.Unlock() p.mappings[m.Host] = m diff --git a/proxy/internal/restrict/restrict.go b/proxy/internal/restrict/restrict.go new file mode 100644 index 000000000..a0d99ce93 --- /dev/null +++ b/proxy/internal/restrict/restrict.go @@ -0,0 +1,183 @@ +// Package restrict provides connection-level access control based on +// IP CIDR ranges and geolocation (country codes). +package restrict + +import ( + "net/netip" + "slices" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/geolocation" +) + +// GeoResolver resolves an IP address to geographic information. +type GeoResolver interface { + LookupAddr(addr netip.Addr) geolocation.Result + Available() bool +} + +// Filter evaluates IP restrictions. 
CIDR checks are performed first +// (cheap), followed by country lookups (more expensive) only when needed. +type Filter struct { + AllowedCIDRs []netip.Prefix + BlockedCIDRs []netip.Prefix + AllowedCountries []string + BlockedCountries []string +} + +// ParseFilter builds a Filter from the raw string slices. Returns nil +// if all slices are empty. +func ParseFilter(allowedCIDRs, blockedCIDRs, allowedCountries, blockedCountries []string) *Filter { + if len(allowedCIDRs) == 0 && len(blockedCIDRs) == 0 && + len(allowedCountries) == 0 && len(blockedCountries) == 0 { + return nil + } + + f := &Filter{ + AllowedCountries: normalizeCountryCodes(allowedCountries), + BlockedCountries: normalizeCountryCodes(blockedCountries), + } + for _, cidr := range allowedCIDRs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + log.Warnf("skip invalid allowed CIDR %q: %v", cidr, err) + continue + } + f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked()) + } + for _, cidr := range blockedCIDRs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + log.Warnf("skip invalid blocked CIDR %q: %v", cidr, err) + continue + } + f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked()) + } + return f +} + +func normalizeCountryCodes(codes []string) []string { + if len(codes) == 0 { + return nil + } + out := make([]string, len(codes)) + for i, c := range codes { + out[i] = strings.ToUpper(c) + } + return out +} + +// Verdict is the result of an access check. +type Verdict int + +const ( + // Allow indicates the address passed all checks. + Allow Verdict = iota + // DenyCIDR indicates the address was blocked by a CIDR rule. + DenyCIDR + // DenyCountry indicates the address was blocked by a country rule. + DenyCountry + // DenyGeoUnavailable indicates that country restrictions are configured + // but the geo lookup is unavailable. + DenyGeoUnavailable +) + +// String returns the deny reason string matching the HTTP auth mechanism names. 
+func (v Verdict) String() string { + switch v { + case Allow: + return "allow" + case DenyCIDR: + return "ip_restricted" + case DenyCountry: + return "country_restricted" + case DenyGeoUnavailable: + return "geo_unavailable" + default: + return "unknown" + } +} + +// Check evaluates whether addr is permitted. CIDR rules are evaluated +// first because they are O(n) prefix comparisons. Country rules run +// only when CIDR checks pass and require a geo lookup. +func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict { + if f == nil { + return Allow + } + + // Normalize v4-mapped-v6 (e.g. ::ffff:10.1.2.3) to plain v4 so that + // IPv4 CIDR rules match regardless of how the address was received. + addr = addr.Unmap() + + if v := f.checkCIDR(addr); v != Allow { + return v + } + return f.checkCountry(addr, geo) +} + +func (f *Filter) checkCIDR(addr netip.Addr) Verdict { + if len(f.AllowedCIDRs) > 0 { + allowed := false + for _, prefix := range f.AllowedCIDRs { + if prefix.Contains(addr) { + allowed = true + break + } + } + if !allowed { + return DenyCIDR + } + } + + for _, prefix := range f.BlockedCIDRs { + if prefix.Contains(addr) { + return DenyCIDR + } + } + return Allow +} + +func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict { + if len(f.AllowedCountries) == 0 && len(f.BlockedCountries) == 0 { + return Allow + } + + if geo == nil || !geo.Available() { + return DenyGeoUnavailable + } + + result := geo.LookupAddr(addr) + if result.CountryCode == "" { + // Unknown country: deny if an allowlist is active, allow otherwise. + // Blocklists are best-effort: unknown countries pass through since + // the default policy is allow. 
+ if len(f.AllowedCountries) > 0 { + return DenyCountry + } + return Allow + } + + if len(f.AllowedCountries) > 0 { + if !slices.Contains(f.AllowedCountries, result.CountryCode) { + return DenyCountry + } + } + + if slices.Contains(f.BlockedCountries, result.CountryCode) { + return DenyCountry + } + + return Allow +} + +// HasRestrictions returns true if any restriction rules are configured. +func (f *Filter) HasRestrictions() bool { + if f == nil { + return false + } + return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 || + len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0 +} diff --git a/proxy/internal/restrict/restrict_test.go b/proxy/internal/restrict/restrict_test.go new file mode 100644 index 000000000..17a5848d8 --- /dev/null +++ b/proxy/internal/restrict/restrict_test.go @@ -0,0 +1,278 @@ +package restrict + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/proxy/internal/geolocation" +) + +type mockGeo struct { + countries map[string]string +} + +func (m *mockGeo) LookupAddr(addr netip.Addr) geolocation.Result { + return geolocation.Result{CountryCode: m.countries[addr.String()]} +} + +func (m *mockGeo) Available() bool { return true } + +func newMockGeo(entries map[string]string) *mockGeo { + return &mockGeo{countries: entries} +} + +func TestFilter_Check_NilFilter(t *testing.T) { + var f *Filter + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) +} + +func TestFilter_Check_AllowedCIDR(t *testing.T) { + f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil)) +} + +func TestFilter_Check_BlockedCIDR(t *testing.T) { + f := ParseFilter(nil, []string{"10.0.0.0/8"}, nil, nil) + + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) + assert.Equal(t, Allow, 
f.Check(netip.MustParseAddr("192.168.1.1"), nil)) +} + +func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) { + f := ParseFilter([]string{"10.0.0.0/8"}, []string{"10.1.0.0/16"}, nil, nil) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "not in allowlist") +} + +func TestFilter_Check_AllowedCountry(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + "2.2.2.2": "DE", + "3.3.3.3": "CN", + }) + f := ParseFilter(nil, nil, []string{"US", "DE"}, nil) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist") +} + +func TestFilter_Check_BlockedCountry(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "CN", + "2.2.2.2": "RU", + "3.3.3.3": "US", + }) + f := ParseFilter(nil, nil, nil, []string{"CN", "RU"}) + + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "US not in blocklist") +} + +func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + "2.2.2.2": "DE", + "3.3.3.3": "CN", + }) + // Allow US and DE, but block DE explicitly. 
+ f := ParseFilter(nil, nil, []string{"US", "DE"}, []string{"DE"}) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist") +} + +func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + }) + f := ParseFilter(nil, nil, []string{"US"}, nil) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist") + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active") +} + +func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "CN", + }) + f := ParseFilter(nil, nil, nil, []string{"CN"}) + + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active") +} + +func TestFilter_Check_CountryWithoutGeo(t *testing.T) { + f := ParseFilter(nil, nil, []string{"US"}, nil) + assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist") +} + +func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) { + f := ParseFilter(nil, nil, nil, []string{"CN"}) + assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist") +} + +func TestFilter_Check_GeoUnavailable(t *testing.T) { + geo := &unavailableGeo{} + + f := ParseFilter(nil, nil, []string{"US"}, nil) + assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist") + + f2 := ParseFilter(nil, nil, nil, 
[]string{"CN"}) + assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist") +} + +func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) { + f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + + // CIDR-only filter should never touch geo, so nil geo is fine. + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil)) + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil)) +} + +func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) { + geo := newMockGeo(map[string]string{ + "10.1.2.3": "CN", + "10.2.3.4": "US", + }) + f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, []string{"CN"}) + + assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), geo), "CIDR denied before country check") +} + +func TestParseFilter_Empty(t *testing.T) { + f := ParseFilter(nil, nil, nil, nil) + assert.Nil(t, f) +} + +func TestParseFilter_InvalidCIDR(t *testing.T) { + f := ParseFilter([]string{"invalid", "10.0.0.0/8"}, nil, nil, nil) + + assert.NotNil(t, f) + assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped") + assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), f.AllowedCIDRs[0]) +} + +func TestFilter_HasRestrictions(t *testing.T) { + assert.False(t, (*Filter)(nil).HasRestrictions()) + assert.False(t, (&Filter{}).HasRestrictions()) + assert.True(t, ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil).HasRestrictions()) + assert.True(t, ParseFilter(nil, nil, []string{"US"}, nil).HasRestrictions()) +} + +func TestFilter_Check_IPv6CIDR(t *testing.T) { + f := ParseFilter([]string{"2001:db8::/32"}, nil, nil, nil) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist") + assert.Equal(t, 
DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 addr not in v6 allowlist") +} + +func TestFilter_Check_IPv4MappedIPv6(t *testing.T) { + f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + + // A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR. + v4mapped := netip.MustParseAddr("::ffff:10.1.2.3") + assert.True(t, v4mapped.Is4In6(), "precondition: address is v4-in-v6") + assert.Equal(t, Allow, f.Check(v4mapped, nil), "v4-mapped-v6 must match v4 CIDR after Unmap") + + v4mappedOutside := netip.MustParseAddr("::ffff:192.168.1.1") + assert.Equal(t, DenyCIDR, f.Check(v4mappedOutside, nil), "v4-mapped-v6 outside v4 CIDR") +} + +func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) { + f := ParseFilter([]string{"10.0.0.0/8", "2001:db8::/32"}, nil, nil, nil) + + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR") + assert.Equal(t, Allow, f.Check(netip.MustParseAddr("::ffff:10.1.2.3"), nil), "v4-mapped matches v4 CIDR") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "v4 not in either CIDR") + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("fe80::1"), nil), "v6 not in either CIDR") +} + +func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) { + // 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24. + f := ParseFilter([]string{"1.1.1.1/24"}, nil, nil, nil) + assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0]) + + // Verify it still matches correctly. 
+ assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.100"), nil)) + assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("1.1.2.1"), nil)) +} + +func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) { + geo := newMockGeo(map[string]string{ + "1.1.1.1": "US", + "2.2.2.2": "DE", + "3.3.3.3": "CN", + }) + + tests := []struct { + name string + allowedCountries []string + blockedCountries []string + addr string + want Verdict + }{ + { + name: "lowercase allowlist matches uppercase MaxMind code", + allowedCountries: []string{"us", "de"}, + addr: "1.1.1.1", + want: Allow, + }, + { + name: "mixed-case allowlist matches", + allowedCountries: []string{"Us", "dE"}, + addr: "2.2.2.2", + want: Allow, + }, + { + name: "lowercase allowlist rejects non-matching country", + allowedCountries: []string{"us", "de"}, + addr: "3.3.3.3", + want: DenyCountry, + }, + { + name: "lowercase blocklist blocks matching country", + blockedCountries: []string{"cn"}, + addr: "3.3.3.3", + want: DenyCountry, + }, + { + name: "mixed-case blocklist blocks matching country", + blockedCountries: []string{"Cn"}, + addr: "3.3.3.3", + want: DenyCountry, + }, + { + name: "lowercase blocklist does not block non-matching country", + blockedCountries: []string{"cn"}, + addr: "1.1.1.1", + want: Allow, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + f := ParseFilter(nil, nil, tc.allowedCountries, tc.blockedCountries) + got := f.Check(netip.MustParseAddr(tc.addr), geo) + assert.Equal(t, tc.want, got) + }) + } +} + +// unavailableGeo simulates a GeoResolver whose database is not loaded. 
+type unavailableGeo struct{} + +func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} } +func (u *unavailableGeo) Available() bool { return false } diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go index 84fde0731..8255c36d3 100644 --- a/proxy/internal/tcp/router.go +++ b/proxy/internal/tcp/router.go @@ -7,12 +7,14 @@ import ( "net" "net/netip" "slices" + "strings" "sync" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/proxy/internal/accesslog" + "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/types" ) @@ -20,6 +22,10 @@ import ( // timeout is configured. const defaultDialTimeout = 30 * time.Second +// errAccessRestricted is returned by relayTCP for access restriction +// denials so callers can skip warn-level logging (already logged at debug). +var errAccessRestricted = errors.New("rejected by access restrictions") + // SNIHost is a typed key for SNI hostname lookups. type SNIHost string @@ -64,6 +70,11 @@ type Route struct { // DialTimeout overrides the default dial timeout for this route. // Zero uses defaultDialTimeout. DialTimeout time.Duration + // SessionIdleTimeout overrides the default idle timeout for relay connections. + // Zero uses DefaultIdleTimeout. + SessionIdleTimeout time.Duration + // Filter holds connection-level IP/geo restrictions. Nil means no restrictions. + Filter *restrict.Filter } // l4Logger sends layer-4 access log entries to the management server. @@ -99,6 +110,7 @@ type Router struct { drainDone chan struct{} observer RelayObserver accessLog l4Logger + geo restrict.GeoResolver // svcCtxs tracks a context per service ID. All relay goroutines for a // service derive from its context; canceling it kills them immediately. svcCtxs map[types.ServiceID]context.Context @@ -144,6 +156,7 @@ func (r *Router) HTTPListener() net.Listener { // stored and resolved by priority at lookup time (HTTP > TCP). 
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback. func (r *Router) AddRoute(host SNIHost, route Route) { + host = SNIHost(strings.ToLower(string(host))) if host == "" { return } @@ -166,6 +179,8 @@ func (r *Router) AddRoute(host SNIHost, route Route) { // Active relay connections for the service are closed immediately. // If other routes remain for the host, they are preserved. func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) { + host = SNIHost(strings.ToLower(string(host))) + r.mu.Lock() defer r.mu.Unlock() @@ -295,7 +310,7 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) { return } - host := SNIHost(sni) + host := SNIHost(strings.ToLower(sni)) route, ok := r.lookupRoute(host) if !ok { r.handleUnmatched(ctx, wrapped) @@ -308,11 +323,13 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) { } if err := r.relayTCP(ctx, wrapped, host, route); err != nil { - r.logger.WithFields(log.Fields{ - "sni": host, - "service_id": route.ServiceID, - "target": route.Target, - }).Warnf("TCP relay: %v", err) + if !errors.Is(err, errAccessRestricted) { + r.logger.WithFields(log.Fields{ + "sni": host, + "service_id": route.ServiceID, + "target": route.Target, + }).Warnf("TCP relay: %v", err) + } _ = wrapped.Close() } } @@ -336,10 +353,12 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) { if fb != nil { if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil { - r.logger.WithFields(log.Fields{ - "service_id": fb.ServiceID, - "target": fb.Target, - }).Warnf("TCP relay (fallback): %v", err) + if !errors.Is(err, errAccessRestricted) { + r.logger.WithFields(log.Fields{ + "service_id": fb.ServiceID, + "target": fb.Target, + }).Warnf("TCP relay (fallback): %v", err) + } _ = conn.Close() } return @@ -427,10 +446,44 @@ func (r *Router) cancelServiceLocked(svcID types.ServiceID) { } } +// SetGeo sets the geolocation lookup used for country-based restrictions. 
+func (r *Router) SetGeo(geo restrict.GeoResolver) { + r.mu.Lock() + defer r.mu.Unlock() + r.geo = geo +} + +// checkRestrictions evaluates the route's access filter against the +// connection's remote address. Returns Allow if the connection is +// permitted, or a deny verdict indicating the reason. +func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict { + if route.Filter == nil { + return restrict.Allow + } + + addr, err := addrFromConn(conn) + if err != nil { + r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr()) + return restrict.DenyCIDR + } + + r.mu.RLock() + geo := r.geo + r.mu.RUnlock() + + return route.Filter.Check(addr, geo) +} + // relayTCP sets up and runs a bidirectional TCP relay. // The caller owns conn and must close it if this method returns an error. // On success (nil error), both conn and backend are closed by the relay. func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error { + if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow { + r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict) + r.logL4Deny(route, conn, verdict) + return errAccessRestricted + } + svcCtx, err := r.acquireRelay(ctx, route) if err != nil { return err @@ -468,8 +521,13 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route }) entry.Debug("TCP relay started") + idleTimeout := route.SessionIdleTimeout + if idleTimeout <= 0 { + idleTimeout = DefaultIdleTimeout + } + start := time.Now() - s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout) + s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout) elapsed := time.Since(start) if obs != nil { @@ -537,12 +595,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, return } - var sourceIP netip.Addr - if remote := conn.RemoteAddr(); remote != nil { - if ap, err := 
netip.ParseAddrPort(remote.String()); err == nil { - sourceIP = ap.Addr().Unmap() - } - } + sourceIP, _ := addrFromConn(conn) al.LogL4(accesslog.L4Entry{ AccountID: route.AccountID, @@ -556,6 +609,28 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, }) } +// logL4Deny sends an access log entry for a denied connection. +func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) { + r.mu.RLock() + al := r.accessLog + r.mu.RUnlock() + + if al == nil { + return + } + + sourceIP, _ := addrFromConn(conn) + + al.LogL4(accesslog.L4Entry{ + AccountID: route.AccountID, + ServiceID: route.ServiceID, + Protocol: route.Protocol, + Host: route.Domain, + SourceIP: sourceIP, + DenyReason: verdict.String(), + }) +} + // getOrCreateServiceCtxLocked returns the context for a service, creating one // if it doesn't exist yet. The context is a child of the server context. // Must be called with mu held. @@ -568,3 +643,16 @@ func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types r.svcCancels[svcID] = cancel return ctx } + +// addrFromConn extracts a netip.Addr from a connection's remote address. 
+func addrFromConn(conn net.Conn) (netip.Addr, error) { + remote := conn.RemoteAddr() + if remote == nil { + return netip.Addr{}, errors.New("no remote address") + } + ap, err := netip.ParseAddrPort(remote.String()) + if err != nil { + return netip.Addr{}, err + } + return ap.Addr().Unmap(), nil +} diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go index 0e2cfe3e1..189cdc622 100644 --- a/proxy/internal/tcp/router_test.go +++ b/proxy/internal/tcp/router_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/types" ) @@ -1668,3 +1669,73 @@ func startEchoPlain(t *testing.T) net.Listener { return ln } + +// fakeAddr implements net.Addr with a custom string representation. +type fakeAddr string + +func (f fakeAddr) Network() string { return "tcp" } +func (f fakeAddr) String() string { return string(f) } + +// fakeConn is a minimal net.Conn with a controllable RemoteAddr. 
+type fakeConn struct { + net.Conn + remote net.Addr +} + +func (f *fakeConn) RemoteAddr() net.Addr { return f.remote } + +func TestCheckRestrictions_UnparseableAddress(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + route := Route{Filter: filter} + + conn := &fakeConn{remote: fakeAddr("not-an-ip")} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "unparsable address must be denied") +} + +func TestCheckRestrictions_NilRemoteAddr(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + route := Route{Filter: filter} + + conn := &fakeConn{remote: nil} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "nil remote address must be denied") +} + +func TestCheckRestrictions_AllowedAndDenied(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + route := Route{Filter: filter} + + allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}} + assert.Equal(t, restrict.Allow, router.checkRestrictions(allowed, route), "10.1.2.3 in allowlist") + + denied := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(denied, route), "192.168.1.1 not in allowlist") +} + +func TestCheckRestrictions_NilFilter(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + route := Route{Filter: nil} + + conn := &fakeConn{remote: fakeAddr("not-an-ip")} + assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "nil filter should allow everything") +} + +func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) { + router := NewPortRouter(log.StandardLogger(), nil) + filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil) + route := 
Route{Filter: filter} + + // net.IPv4() returns a 16-byte v4-in-v6 representation internally. + // The restriction check must Unmap it to match the v4 CIDR. + conn := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 5678}} + assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "v4-in-v6 TCPAddr must match v4 CIDR") + + // Explicitly v4-mapped-v6 address string. + conn6 := &fakeConn{remote: fakeAddr("[::ffff:10.1.2.3]:5678")} + assert.Equal(t, restrict.Allow, router.checkRestrictions(conn6, route), "::ffff:10.1.2.3 must match v4 CIDR") + + connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")} + assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR") +} diff --git a/proxy/internal/udp/relay.go b/proxy/internal/udp/relay.go index f2f58e858..d20ecf48b 100644 --- a/proxy/internal/udp/relay.go +++ b/proxy/internal/udp/relay.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/netutil" + "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/types" ) @@ -67,6 +68,8 @@ type Relay struct { dialTimeout time.Duration sessionTTL time.Duration maxSessions int + filter *restrict.Filter + geo restrict.GeoResolver mu sync.RWMutex sessions map[clientAddr]*session @@ -114,6 +117,10 @@ type RelayConfig struct { SessionTTL time.Duration MaxSessions int AccessLog l4Logger + // Filter holds connection-level IP/geo restrictions. Nil means no restrictions. + Filter *restrict.Filter + // Geo is the geolocation lookup used for country-based restrictions. + Geo restrict.GeoResolver } // New creates a UDP relay for the given listener and backend target. 
@@ -146,6 +153,8 @@ func New(parentCtx context.Context, cfg RelayConfig) *Relay { dialTimeout: dialTimeout, sessionTTL: sessionTTL, maxSessions: maxSessions, + filter: cfg.Filter, + geo: cfg.Geo, sessions: make(map[clientAddr]*session), bufPool: sync.Pool{ New: func() any { @@ -166,9 +175,18 @@ func (r *Relay) ServiceID() types.ServiceID { // SetObserver sets the session lifecycle observer. Must be called before Serve. func (r *Relay) SetObserver(obs SessionObserver) { + r.mu.Lock() + defer r.mu.Unlock() r.observer = obs } +// getObserver returns the current session lifecycle observer. +func (r *Relay) getObserver() SessionObserver { + r.mu.RLock() + defer r.mu.RUnlock() + return r.observer +} + // Serve starts the relay loop. It blocks until the context is canceled // or the listener is closed. func (r *Relay) Serve() { @@ -209,8 +227,8 @@ func (r *Relay) Serve() { } sess.bytesIn.Add(int64(nw)) - if r.observer != nil { - r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw) + if obs := r.getObserver(); obs != nil { + obs.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw) } r.bufPool.Put(bufp) } @@ -234,6 +252,10 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { return nil, r.ctx.Err() } + if err := r.checkAccessRestrictions(addr); err != nil { + return nil, err + } + r.mu.Lock() if sess, ok = r.sessions[key]; ok && sess != nil { @@ -248,16 +270,16 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { if len(r.sessions) >= r.maxSessions { r.mu.Unlock() - if r.observer != nil { - r.observer.UDPSessionRejected(r.accountID) + if obs := r.getObserver(); obs != nil { + obs.UDPSessionRejected(r.accountID) } return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions) } if !r.sessLimiter.Allow() { r.mu.Unlock() - if r.observer != nil { - r.observer.UDPSessionRejected(r.accountID) + if obs := r.getObserver(); obs != nil { + obs.UDPSessionRejected(r.accountID) } return nil, fmt.Errorf("session creation 
rate limited") } @@ -274,8 +296,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { r.mu.Lock() delete(r.sessions, key) r.mu.Unlock() - if r.observer != nil { - r.observer.UDPSessionDialError(r.accountID) + if obs := r.getObserver(); obs != nil { + obs.UDPSessionDialError(r.accountID) } return nil, fmt.Errorf("dial backend %s: %w", r.target, err) } @@ -293,8 +315,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { r.sessions[key] = sess r.mu.Unlock() - if r.observer != nil { - r.observer.UDPSessionStarted(r.accountID) + if obs := r.getObserver(); obs != nil { + obs.UDPSessionStarted(r.accountID) } r.sessWg.Go(func() { @@ -305,6 +327,21 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) { return sess, nil } +func (r *Relay) checkAccessRestrictions(addr net.Addr) error { + if r.filter == nil { + return nil + } + clientIP, err := addrFromUDPAddr(addr) + if err != nil { + return fmt.Errorf("parse client address %s for restriction check: %w", addr, err) + } + if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow { + r.logDeny(clientIP, v) + return fmt.Errorf("access restricted for %s", addr) + } + return nil +} + // relayBackendToClient reads packets from the backend and writes them // back to the client through the public-facing listener. 
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) { @@ -332,8 +369,8 @@ func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) { } sess.bytesOut.Add(int64(nw)) - if r.observer != nil { - r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw) + if obs := r.getObserver(); obs != nil { + obs.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw) } } } @@ -402,9 +439,10 @@ func (r *Relay) cleanupIdleSessions() { } r.mu.Unlock() + obs := r.getObserver() for _, sess := range expired { - if r.observer != nil { - r.observer.UDPSessionEnded(r.accountID) + if obs != nil { + obs.UDPSessionEnded(r.accountID) } r.logSessionEnd(sess) } @@ -429,8 +467,8 @@ func (r *Relay) removeSession(sess *session) { if removed { r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)", sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) - if r.observer != nil { - r.observer.UDPSessionEnded(r.accountID) + if obs := r.getObserver(); obs != nil { + obs.UDPSessionEnded(r.accountID) } r.logSessionEnd(sess) } @@ -459,6 +497,22 @@ func (r *Relay) logSessionEnd(sess *session) { }) } +// logDeny sends an access log entry for a denied UDP packet. +func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) { + if r.accessLog == nil { + return + } + + r.accessLog.LogL4(accesslog.L4Entry{ + AccountID: r.accountID, + ServiceID: r.serviceID, + Protocol: accesslog.ProtocolUDP, + Host: r.domain, + SourceIP: clientIP, + DenyReason: verdict.String(), + }) +} + // Close stops the relay, waits for all session goroutines to exit, // and cleans up remaining sessions. 
func (r *Relay) Close() { @@ -485,12 +539,22 @@ func (r *Relay) Close() { } r.mu.Unlock() + obs := r.getObserver() for _, sess := range closedSessions { - if r.observer != nil { - r.observer.UDPSessionEnded(r.accountID) + if obs != nil { + obs.UDPSessionEnded(r.accountID) } r.logSessionEnd(sess) } r.sessWg.Wait() } + +// addrFromUDPAddr extracts a netip.Addr from a net.Addr. +func addrFromUDPAddr(addr net.Addr) (netip.Addr, error) { + ap, err := netip.ParseAddrPort(addr.String()) + if err != nil { + return netip.Addr{}, err + } + return ap.Addr().Unmap(), nil +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index ebecfc6f6..8af151446 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -490,7 +490,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T logger := log.New() logger.SetLevel(log.WarnLevel) - authMw := auth.NewMiddleware(logger, nil) + authMw := auth.NewMiddleware(logger, nil, nil) proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger) clusterAddress := "test.proxy.io" @@ -511,6 +511,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T 0, proxytypes.AccountID(mapping.GetAccountId()), proxytypes.ServiceID(mapping.GetId()), + nil, ) require.NoError(t, err) diff --git a/proxy/server.go b/proxy/server.go index 649d49c9a..c4d12859b 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -43,12 +43,14 @@ import ( "github.com/netbirdio/netbird/proxy/internal/certwatch" "github.com/netbirdio/netbird/proxy/internal/conntrack" "github.com/netbirdio/netbird/proxy/internal/debug" + "github.com/netbirdio/netbird/proxy/internal/geolocation" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/k8s" proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" 
"github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/roundtrip" nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" "github.com/netbirdio/netbird/proxy/internal/types" @@ -59,7 +61,6 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) - // portRouter bundles a per-port Router with its listener and cancel func. type portRouter struct { router *nbtcp.Router @@ -95,6 +96,9 @@ type Server struct { // so they can be closed during graceful shutdown, since http.Server.Shutdown // does not handle them. hijackTracker conntrack.HijackTracker + // geo resolves IP addresses to country/city for access restrictions and access logs. + geo restrict.GeoResolver + geoRaw *geolocation.Lookup // routerReady is closed once mainRouter is fully initialized. // The mapping worker waits on this before processing updates. @@ -159,10 +163,38 @@ type Server struct { // SupportsCustomPorts indicates whether the proxy can bind arbitrary // ports for TCP/UDP/TLS services. SupportsCustomPorts bool - // DefaultDialTimeout is the default timeout for establishing backend - // connections when no per-service timeout is configured. Zero means - // each transport uses its own hardcoded default (typically 30s). - DefaultDialTimeout time.Duration + // MaxDialTimeout caps the per-service backend dial timeout. + // When the API sends a timeout, it is clamped to this value. + // When the API sends no timeout, this value is used as the default. + // Zero means no cap (the proxy honors whatever management sends). + MaxDialTimeout time.Duration + // GeoDataDir is the directory containing GeoLite2 MMDB files for + // country-based access restrictions. Empty disables geo lookups. + GeoDataDir string + // MaxSessionIdleTimeout caps the per-service session idle timeout. + // Zero means no cap (the proxy honors whatever management sends). 
+ // Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments. + MaxSessionIdleTimeout time.Duration +} + +// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured. +func (s *Server) clampIdleTimeout(d time.Duration) time.Duration { + if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout { + return s.MaxSessionIdleTimeout + } + return d +} + +// clampDialTimeout returns d capped to MaxDialTimeout when configured. +// If d is zero, MaxDialTimeout is used as the default. +func (s *Server) clampDialTimeout(d time.Duration) time.Duration { + if s.MaxDialTimeout <= 0 { + return d + } + if d <= 0 || d > s.MaxDialTimeout { + return s.MaxDialTimeout + } + return d } // NotifyStatus sends a status update to management about tunnel connectivity. @@ -226,7 +258,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) runCtx, runCancel := context.WithCancel(ctx) defer runCancel() - go s.newManagementMappingWorker(runCtx, s.mgmtClient) // Initialize the netbird client, this is required to build peer connections // to proxy over. @@ -236,6 +267,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { PreSharedKey: s.PreSharedKey, }, s.Logger, s, s.mgmtClient) + // Create health checker before the mapping worker so it can track + // management connectivity from the first stream connection. + s.healthChecker = health.NewChecker(s.Logger, s.netbird) + + go s.newManagementMappingWorker(runCtx, s.mgmtClient) + tlsConfig, err := s.configureTLS(ctx) if err != nil { return err @@ -244,14 +281,33 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. 
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger) + geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir) + if err != nil { + return fmt.Errorf("initialize geolocation: %w", err) + } + s.geoRaw = geoLookup + if geoLookup != nil { + s.geo = geoLookup + } + + var startupOK bool + defer func() { + if startupOK { + return + } + if s.geoRaw != nil { + if err := s.geoRaw.Close(); err != nil { + s.Logger.Debugf("close geolocation on startup failure: %v", err) + } + } + }() + // Configure the authentication middleware with session validator for OIDC group checks. - s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) + s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo) // Configure Access logs to management server. s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) - s.healthChecker = health.NewChecker(s.Logger, s.netbird) - s.startDebugEndpoint() if err := s.startHealthServer(); err != nil { @@ -294,6 +350,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), } + startupOK = true + httpsErr := make(chan error, 1) go func() { s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") @@ -691,6 +749,16 @@ func (s *Server) shutdownServices() { s.portRouterWg.Wait() wg.Wait() + + if s.accessLog != nil { + s.accessLog.Close() + } + + if s.geoRaw != nil { + if err := s.geoRaw.Close(); err != nil { + s.Logger.Debugf("close geolocation: %v", err) + } + } } // resolveDialFunc returns a DialContextFunc that dials through the @@ -1073,15 +1141,20 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin return fmt.Errorf("router for TCP port %d: %w", port, err) } + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) + + router.SetGeo(s.geo) router.SetFallback(nbtcp.Route{ - Type: nbtcp.RouteTCP, - AccountID: accountID, - ServiceID: 
svcID, - Domain: mapping.GetDomain(), - Protocol: accesslog.ProtocolTCP, - Target: targetAddr, - ProxyProtocol: s.l4ProxyProtocol(mapping), - DialTimeout: s.l4DialTimeout(mapping), + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTCP, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), + Filter: parseRestrictions(mapping), }) s.portMu.Lock() @@ -1108,6 +1181,8 @@ func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMappin return fmt.Errorf("empty target address for UDP service %s", svcID) } + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) + if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil { return fmt.Errorf("UDP relay for service %s: %w", svcID, err) } @@ -1141,15 +1216,20 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin return fmt.Errorf("router for TLS port %d: %w", tlsPort, err) } + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) + + router.SetGeo(s.geo) router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ - Type: nbtcp.RouteTCP, - AccountID: accountID, - ServiceID: svcID, - Domain: mapping.GetDomain(), - Protocol: accesslog.ProtocolTLS, - Target: targetAddr, - ProxyProtocol: s.l4ProxyProtocol(mapping), - DialTimeout: s.l4DialTimeout(mapping), + Type: nbtcp.RouteTCP, + AccountID: accountID, + ServiceID: svcID, + Domain: mapping.GetDomain(), + Protocol: accesslog.ProtocolTLS, + Target: targetAddr, + ProxyProtocol: s.l4ProxyProtocol(mapping), + DialTimeout: s.l4DialTimeout(mapping), + SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), + Filter: parseRestrictions(mapping), }) if tlsPort != s.mainPort { @@ -1181,6 +1261,32 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) 
roundtrip.Ser } } +// parseRestrictions converts a proto mapping's access restrictions into +// a restrict.Filter. Returns nil if the mapping has no restrictions. +func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter { + r := mapping.GetAccessRestrictions() + if r == nil { + return nil + } + return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries()) +} + +// warnIfGeoUnavailable logs a warning if the mapping has country restrictions +// but the proxy has no geolocation database loaded. All requests to this +// service will be denied at runtime (fail-close). +func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) { + if r == nil { + return + } + if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 { + return + } + if s.geo != nil && s.geo.Available() { + return + } + s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain) +} + // l4TargetAddress extracts and validates the target address from a mapping's // first path entry. Returns empty string if no paths exist or the address is // not a valid host:port. @@ -1210,15 +1316,15 @@ func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool { } // l4DialTimeout returns the dial timeout from the first target's options, -// falling back to the server's DefaultDialTimeout. +// clamped to MaxDialTimeout. 
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration { paths := mapping.GetPath() if len(paths) > 0 { if d := paths[0].GetOptions().GetRequestTimeout(); d != nil { - return d.AsDuration() + return s.clampDialTimeout(d.AsDuration()) } } - return s.DefaultDialTimeout + return s.clampDialTimeout(0) } // l4SessionIdleTimeout returns the configured session idle timeout from the @@ -1254,7 +1360,9 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t dialFn, err := s.resolveDialFunc(accountID) if err != nil { - _ = listener.Close() + if err := listener.Close(); err != nil { + s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err) + } return fmt.Errorf("resolve dialer for UDP: %w", err) } @@ -1273,8 +1381,10 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t ServiceID: svcID, DialFunc: dialFn, DialTimeout: s.l4DialTimeout(mapping), - SessionTTL: l4SessionIdleTimeout(mapping), + SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), AccessLog: s.accessLog, + Filter: parseRestrictions(mapping), + Geo: s.geo, }) relay.SetObserver(s.meter) @@ -1306,9 +1416,15 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) if mapping.GetAuth().GetOidc() { schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto)) } + for _, ha := range mapping.GetAuth().GetHeaderAuths() { + schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader())) + } + + ipRestrictions := parseRestrictions(mapping) + s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil { + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, 
accountID, svcID, ipRestrictions); err != nil { return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) } m := s.protoToMapping(ctx, mapping) @@ -1449,12 +1565,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping pt.RequestTimeout = d.AsDuration() } } - if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 { - pt.RequestTimeout = s.DefaultDialTimeout - } + pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout) paths[pathMapping.GetPath()] = pt } - return proxy.Mapping{ + m := proxy.Mapping{ ID: types.ServiceID(mapping.GetId()), AccountID: types.AccountID(mapping.GetAccountId()), Host: mapping.GetDomain(), @@ -1462,6 +1576,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping PassHostHeader: mapping.GetPassHostHeader(), RewriteRedirects: mapping.GetRewriteRedirects(), } + for _, ha := range mapping.GetAuth().GetHeaderAuths() { + m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader()) + } + return m } func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode { diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 4b851bf19..66f39b92f 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -2826,6 +2826,10 @@ components: type: string description: "City name from geolocation" example: "San Francisco" + subdivision_code: + type: string + description: "First-level administrative subdivision ISO code (e.g. state/province)" + example: "CA" bytes_upload: type: integer format: int64 @@ -2952,26 +2956,32 @@ components: id: type: string description: Service ID + example: "cs8i4ug6lnn4g9hqv7mg" name: type: string description: Service name + example: "myapp.example.netbird.app" domain: type: string description: Domain for the service + example: "myapp.example.netbird.app" mode: type: string description: Service mode. 
"http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. enum: [http, tcp, udp, tls] default: http + example: "http" listen_port: type: integer minimum: 0 maximum: 65535 description: Port the proxy listens on (L4/TLS only) + example: 8443 port_auto_assigned: type: boolean description: Whether the listen port was auto-assigned readOnly: true + example: false proxy_cluster: type: string description: The proxy cluster handling this service (derived from domain) @@ -2984,14 +2994,19 @@ components: enabled: type: boolean description: Whether the service is enabled + example: true pass_host_header: type: boolean description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address + example: false rewrite_redirects: type: boolean description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain + example: false auth: $ref: '#/components/schemas/ServiceAuthConfig' + access_restrictions: + $ref: '#/components/schemas/AccessRestrictions' meta: $ref: '#/components/schemas/ServiceMeta' required: @@ -3035,19 +3050,23 @@ components: name: type: string description: Service name + example: "myapp.example.netbird.app" domain: type: string description: Domain for the service + example: "myapp.example.netbird.app" mode: type: string description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. enum: [http, tcp, udp, tls] default: http + example: "http" listen_port: type: integer minimum: 0 maximum: 65535 description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. 
+ example: 5432 targets: type: array items: @@ -3057,14 +3076,19 @@ components: type: boolean description: Whether the service is enabled default: true + example: true pass_host_header: type: boolean description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address + example: false rewrite_redirects: type: boolean description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain + example: false auth: $ref: '#/components/schemas/ServiceAuthConfig' + access_restrictions: + $ref: '#/components/schemas/AccessRestrictions' required: - name - domain @@ -3075,13 +3099,16 @@ components: skip_tls_verify: type: boolean description: Skip TLS certificate verification for this backend + example: false request_timeout: type: string description: Per-target response timeout as a Go duration string (e.g. "30s", "2m") + example: "30s" path_rewrite: type: string description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. enum: [preserve] + example: "preserve" custom_headers: type: object description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected. @@ -3091,40 +3118,50 @@ components: additionalProperties: type: string pattern: '^[^\r\n]*$' + example: {"X-Custom-Header": "value"} proxy_protocol: type: boolean description: Send PROXY Protocol v2 header to this backend (TCP/TLS only) + example: false session_idle_timeout: type: string - description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. + description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). 
+ example: "2m" ServiceTarget: type: object properties: target_id: type: string description: Target ID + example: "cs8i4ug6lnn4g9hqv7mg" target_type: type: string description: Target type enum: [peer, host, domain, subnet] + example: "subnet" path: type: string description: URL path prefix for this target (HTTP only) + example: "/" protocol: type: string description: Protocol to use when connecting to the backend enum: [http, https, tcp, udp] + example: "http" host: type: string description: Backend ip or domain for this target + example: "10.10.0.1" port: type: integer minimum: 1 maximum: 65535 description: Backend port for this target + example: 8080 enabled: type: boolean description: Whether this target is enabled + example: true options: $ref: '#/components/schemas/ServiceTargetOptions' required: @@ -3144,15 +3181,73 @@ components: $ref: '#/components/schemas/BearerAuthConfig' link_auth: $ref: '#/components/schemas/LinkAuthConfig' + header_auths: + type: array + items: + $ref: '#/components/schemas/HeaderAuthConfig' + HeaderAuthConfig: + type: object + description: Static header-value authentication. The proxy checks that the named header matches the configured value. + properties: + enabled: + type: boolean + description: Whether header auth is enabled + example: true + header: + type: string + description: HTTP header name to check (e.g. "Authorization", "X-API-Key") + example: "X-API-Key" + value: + type: string + description: Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses. + example: "my-secret-api-key" + required: + - enabled + - header + - value + AccessRestrictions: + type: object + description: Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. + properties: + allowed_cidrs: + type: array + items: + type: string + format: cidr + example: "192.168.1.0/24" + description: CIDR allowlist. 
If non-empty, only IPs matching these CIDRs are allowed. + blocked_cidrs: + type: array + items: + type: string + format: cidr + example: "10.0.0.0/8" + description: CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs. + allowed_countries: + type: array + items: + type: string + pattern: '^[a-zA-Z]{2}$' + example: "US" + description: ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted. + blocked_countries: + type: array + items: + type: string + pattern: '^[a-zA-Z]{2}$' + example: "DE" + description: ISO 3166-1 alpha-2 country codes to block. PasswordAuthConfig: type: object properties: enabled: type: boolean description: Whether password auth is enabled + example: true password: type: string description: Auth password + example: "s3cret" required: - enabled - password @@ -3162,9 +3257,11 @@ components: enabled: type: boolean description: Whether PIN auth is enabled + example: false pin: type: string description: PIN value + example: "1234" required: - enabled - pin @@ -3174,10 +3271,12 @@ components: enabled: type: boolean description: Whether bearer auth is enabled + example: true distribution_groups: type: array items: type: string + example: "ch8i4ug6lnn4g9hqv7mg" description: List of group IDs that can use bearer auth required: - enabled @@ -3187,6 +3286,7 @@ components: enabled: type: boolean description: Whether link auth is enabled + example: false required: - enabled ProxyCluster: @@ -3217,20 +3317,25 @@ components: id: type: string description: Domain ID + example: "ds8i4ug6lnn4g9hqv7mg" domain: type: string description: Domain name + example: "example.netbird.app" validated: type: boolean description: Whether the domain has been validated + example: true type: $ref: '#/components/schemas/ReverseProxyDomainType' target_cluster: type: string description: The proxy cluster this domain is validated against (only for custom domains) + example: "eu.proxy.netbird.io" 
supports_custom_ports: type: boolean description: Whether the cluster supports binding arbitrary TCP/UDP ports + example: true required: - id - domain @@ -3242,9 +3347,11 @@ components: domain: type: string description: Domain name + example: "myapp.example.com" target_cluster: type: string description: The proxy cluster this domain should be validated against + example: "eu.proxy.netbird.io" required: - domain - target_cluster diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 4ec3b871a..693449d14 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1276,6 +1276,21 @@ func (e PutApiIntegrationsMspTenantsIdInviteJSONBodyValue) Valid() bool { } } +// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. +type AccessRestrictions struct { + // AllowedCidrs CIDR allowlist. If non-empty, only IPs matching these CIDRs are allowed. + AllowedCidrs *[]string `json:"allowed_cidrs,omitempty"` + + // AllowedCountries ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted. + AllowedCountries *[]string `json:"allowed_countries,omitempty"` + + // BlockedCidrs CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs. + BlockedCidrs *[]string `json:"blocked_cidrs,omitempty"` + + // BlockedCountries ISO 3166-1 alpha-2 country codes to block. + BlockedCountries *[]string `json:"blocked_countries,omitempty"` +} + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // CityName Commonly used English name of the city @@ -1988,6 +2003,18 @@ type GroupRequest struct { Resources *[]Resource `json:"resources,omitempty"` } +// HeaderAuthConfig Static header-value authentication. The proxy checks that the named header matches the configured value. 
+type HeaderAuthConfig struct { + // Enabled Whether header auth is enabled + Enabled bool `json:"enabled"` + + // Header HTTP header name to check (e.g. "Authorization", "X-API-Key") + Header string `json:"header"` + + // Value Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses. + Value string `json:"value"` +} + // HuntressMatchAttributes Attribute conditions to match when approving agents type HuntressMatchAttributes struct { // DefenderPolicyStatus Policy status of Defender AV for Managed Antivirus. @@ -3324,6 +3351,9 @@ type ProxyAccessLog struct { // StatusCode HTTP status code returned StatusCode int `json:"status_code"` + // SubdivisionCode First-level administrative subdivision ISO code (e.g. state/province) + SubdivisionCode *string `json:"subdivision_code,omitempty"` + // Timestamp Timestamp when the request was made Timestamp time.Time `json:"timestamp"` @@ -3562,7 +3592,9 @@ type SentinelOneMatchAttributesNetworkStatus string // Service defines model for Service. type Service struct { - Auth ServiceAuthConfig `json:"auth"` + // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. + AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"` + Auth ServiceAuthConfig `json:"auth"` // Domain Domain for the service Domain string `json:"domain"` @@ -3605,6 +3637,7 @@ type ServiceMode string // ServiceAuthConfig defines model for ServiceAuthConfig. type ServiceAuthConfig struct { BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"` + HeaderAuths *[]HeaderAuthConfig `json:"header_auths,omitempty"` LinkAuth *LinkAuthConfig `json:"link_auth,omitempty"` PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty"` PinAuth *PINAuthConfig `json:"pin_auth,omitempty"` @@ -3627,7 +3660,9 @@ type ServiceMetaStatus string // ServiceRequest defines model for ServiceRequest. 
type ServiceRequest struct { - Auth *ServiceAuthConfig `json:"auth,omitempty"` + // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. + AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"` + Auth *ServiceAuthConfig `json:"auth,omitempty"` // Domain Domain for the service Domain string `json:"domain"` @@ -3702,7 +3737,7 @@ type ServiceTargetOptions struct { // RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m") RequestTimeout *string `json:"request_timeout,omitempty"` - // SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. + // SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"` // SkipTlsVerify Skip TLS certificate verification for this backend diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 115ac5101..e5a2d6a98 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.26.0 +// protoc-gen-go v1.36.6 // protoc v6.33.3 // source: proxy_service.proto @@ -13,6 +13,7 @@ import ( timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -177,21 +178,18 @@ func (ProxyStatus) EnumDescriptor() ([]byte, []int) { // ProxyCapabilities describes what a proxy can handle. type ProxyCapabilities struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. 
SupportsCustomPorts *bool `protobuf:"varint,1,opt,name=supports_custom_ports,json=supportsCustomPorts,proto3,oneof" json:"supports_custom_ports,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ProxyCapabilities) Reset() { *x = ProxyCapabilities{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ProxyCapabilities) String() string { @@ -202,7 +200,7 @@ func (*ProxyCapabilities) ProtoMessage() {} func (x *ProxyCapabilities) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -226,24 +224,21 @@ func (x *ProxyCapabilities) GetSupportsCustomPorts() bool { // GetMappingUpdateRequest is sent to initialise a mapping stream. 
type GetMappingUpdateRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` unknownFields protoimpl.UnknownFields - - ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` - Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` - StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` - Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` - Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` + sizeCache protoimpl.SizeCache } func (x *GetMappingUpdateRequest) Reset() { *x = GetMappingUpdateRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetMappingUpdateRequest) String() string { @@ -254,7 +249,7 @@ func (*GetMappingUpdateRequest) ProtoMessage() {} func (x *GetMappingUpdateRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if 
ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -308,23 +303,20 @@ func (x *GetMappingUpdateRequest) GetCapabilities() *ProxyCapabilities { // No mappings may be sent to test the liveness of the Proxy. // Mappings that are sent should be interpreted by the Proxy appropriately. type GetMappingUpdateResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Mapping []*ProxyMapping `protobuf:"bytes,1,rep,name=mapping,proto3" json:"mapping,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Mapping []*ProxyMapping `protobuf:"bytes,1,rep,name=mapping,proto3" json:"mapping,omitempty"` // initial_sync_complete is set on the last message of the initial snapshot. // The proxy uses this to signal that startup is complete. InitialSyncComplete bool `protobuf:"varint,2,opt,name=initial_sync_complete,json=initialSyncComplete,proto3" json:"initial_sync_complete,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GetMappingUpdateResponse) Reset() { *x = GetMappingUpdateResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetMappingUpdateResponse) String() string { @@ -335,7 +327,7 @@ func (*GetMappingUpdateResponse) ProtoMessage() {} func (x *GetMappingUpdateResponse) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -365,27 +357,24 @@ func (x *GetMappingUpdateResponse) GetInitialSyncComplete() bool { } type PathTargetOptions struct { - state protoimpl.MessageState - sizeCache 
protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - SkipTlsVerify bool `protobuf:"varint,1,opt,name=skip_tls_verify,json=skipTlsVerify,proto3" json:"skip_tls_verify,omitempty"` - RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` - PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` - CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + state protoimpl.MessageState `protogen:"open.v1"` + SkipTlsVerify bool `protobuf:"varint,1,opt,name=skip_tls_verify,json=skipTlsVerify,proto3" json:"skip_tls_verify,omitempty"` + RequestTimeout *durationpb.Duration `protobuf:"bytes,2,opt,name=request_timeout,json=requestTimeout,proto3" json:"request_timeout,omitempty"` + PathRewrite PathRewriteMode `protobuf:"varint,3,opt,name=path_rewrite,json=pathRewrite,proto3,enum=management.PathRewriteMode" json:"path_rewrite,omitempty"` + CustomHeaders map[string]string `protobuf:"bytes,4,rep,name=custom_headers,json=customHeaders,proto3" json:"custom_headers,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Send PROXY protocol v2 header to this backend. ProxyProtocol bool `protobuf:"varint,5,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` // Idle timeout before a UDP session is reaped. 
SessionIdleTimeout *durationpb.Duration `protobuf:"bytes,6,opt,name=session_idle_timeout,json=sessionIdleTimeout,proto3" json:"session_idle_timeout,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PathTargetOptions) Reset() { *x = PathTargetOptions{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *PathTargetOptions) String() string { @@ -396,7 +385,7 @@ func (*PathTargetOptions) ProtoMessage() {} func (x *PathTargetOptions) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -454,22 +443,19 @@ func (x *PathTargetOptions) GetSessionIdleTimeout() *durationpb.Duration { } type PathMapping struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Target string `protobuf:"bytes,2,opt,name=target,proto3" json:"target,omitempty"` + Options *PathTargetOptions `protobuf:"bytes,3,opt,name=options,proto3" json:"options,omitempty"` unknownFields protoimpl.UnknownFields - - Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` - Target string `protobuf:"bytes,2,opt,name=target,proto3" json:"target,omitempty"` - Options *PathTargetOptions `protobuf:"bytes,3,opt,name=options,proto3" json:"options,omitempty"` + sizeCache protoimpl.SizeCache } func (x *PathMapping) Reset() { *x = PathMapping{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[4] - ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *PathMapping) String() string { @@ -480,7 +466,7 @@ func (*PathMapping) ProtoMessage() {} func (x *PathMapping) ProtoReflect() protoreflect.Message { mi := &file_proxy_service_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -516,25 +502,77 @@ func (x *PathMapping) GetOptions() *PathTargetOptions { return nil } -type Authentication struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache +type HeaderAuth struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Header name to check, e.g. "Authorization", "X-API-Key". + Header string `protobuf:"bytes,1,opt,name=header,proto3" json:"header,omitempty"` + // argon2id hash of the expected full header value. 
+ HashedValue string `protobuf:"bytes,2,opt,name=hashed_value,json=hashedValue,proto3" json:"hashed_value,omitempty"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} - SessionKey string `protobuf:"bytes,1,opt,name=session_key,json=sessionKey,proto3" json:"session_key,omitempty"` - MaxSessionAgeSeconds int64 `protobuf:"varint,2,opt,name=max_session_age_seconds,json=maxSessionAgeSeconds,proto3" json:"max_session_age_seconds,omitempty"` - Password bool `protobuf:"varint,3,opt,name=password,proto3" json:"password,omitempty"` - Pin bool `protobuf:"varint,4,opt,name=pin,proto3" json:"pin,omitempty"` - Oidc bool `protobuf:"varint,5,opt,name=oidc,proto3" json:"oidc,omitempty"` +func (x *HeaderAuth) Reset() { + *x = HeaderAuth{} + mi := &file_proxy_service_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HeaderAuth) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HeaderAuth) ProtoMessage() {} + +func (x *HeaderAuth) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HeaderAuth.ProtoReflect.Descriptor instead. 
+func (*HeaderAuth) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{5} +} + +func (x *HeaderAuth) GetHeader() string { + if x != nil { + return x.Header + } + return "" +} + +func (x *HeaderAuth) GetHashedValue() string { + if x != nil { + return x.HashedValue + } + return "" +} + +type Authentication struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionKey string `protobuf:"bytes,1,opt,name=session_key,json=sessionKey,proto3" json:"session_key,omitempty"` + MaxSessionAgeSeconds int64 `protobuf:"varint,2,opt,name=max_session_age_seconds,json=maxSessionAgeSeconds,proto3" json:"max_session_age_seconds,omitempty"` + Password bool `protobuf:"varint,3,opt,name=password,proto3" json:"password,omitempty"` + Pin bool `protobuf:"varint,4,opt,name=pin,proto3" json:"pin,omitempty"` + Oidc bool `protobuf:"varint,5,opt,name=oidc,proto3" json:"oidc,omitempty"` + HeaderAuths []*HeaderAuth `protobuf:"bytes,6,rep,name=header_auths,json=headerAuths,proto3" json:"header_auths,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Authentication) Reset() { *x = Authentication{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Authentication) String() string { @@ -544,8 +582,8 @@ func (x *Authentication) String() string { func (*Authentication) ProtoMessage() {} func (x *Authentication) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[6] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -557,7 +595,7 @@ func (x *Authentication) ProtoReflect() 
protoreflect.Message { // Deprecated: Use Authentication.ProtoReflect.Descriptor instead. func (*Authentication) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{5} + return file_proxy_service_proto_rawDescGZIP(), []int{6} } func (x *Authentication) GetSessionKey() string { @@ -595,11 +633,83 @@ func (x *Authentication) GetOidc() bool { return false } -type ProxyMapping struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields +func (x *Authentication) GetHeaderAuths() []*HeaderAuth { + if x != nil { + return x.HeaderAuths + } + return nil +} +type AccessRestrictions struct { + state protoimpl.MessageState `protogen:"open.v1"` + AllowedCidrs []string `protobuf:"bytes,1,rep,name=allowed_cidrs,json=allowedCidrs,proto3" json:"allowed_cidrs,omitempty"` + BlockedCidrs []string `protobuf:"bytes,2,rep,name=blocked_cidrs,json=blockedCidrs,proto3" json:"blocked_cidrs,omitempty"` + AllowedCountries []string `protobuf:"bytes,3,rep,name=allowed_countries,json=allowedCountries,proto3" json:"allowed_countries,omitempty"` + BlockedCountries []string `protobuf:"bytes,4,rep,name=blocked_countries,json=blockedCountries,proto3" json:"blocked_countries,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AccessRestrictions) Reset() { + *x = AccessRestrictions{} + mi := &file_proxy_service_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AccessRestrictions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AccessRestrictions) ProtoMessage() {} + +func (x *AccessRestrictions) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
AccessRestrictions.ProtoReflect.Descriptor instead. +func (*AccessRestrictions) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{7} +} + +func (x *AccessRestrictions) GetAllowedCidrs() []string { + if x != nil { + return x.AllowedCidrs + } + return nil +} + +func (x *AccessRestrictions) GetBlockedCidrs() []string { + if x != nil { + return x.BlockedCidrs + } + return nil +} + +func (x *AccessRestrictions) GetAllowedCountries() []string { + if x != nil { + return x.AllowedCountries + } + return nil +} + +func (x *AccessRestrictions) GetBlockedCountries() []string { + if x != nil { + return x.BlockedCountries + } + return nil +} + +type ProxyMapping struct { + state protoimpl.MessageState `protogen:"open.v1"` Type ProxyMappingUpdateType `protobuf:"varint,1,opt,name=type,proto3,enum=management.ProxyMappingUpdateType" json:"type,omitempty"` Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` AccountId string `protobuf:"bytes,3,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` @@ -616,16 +726,17 @@ type ProxyMapping struct { // Service mode: "http", "tcp", "udp", or "tls". Mode string `protobuf:"bytes,10,opt,name=mode,proto3" json:"mode,omitempty"` // For L4/TLS: the port the proxy listens on. 
- ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` + ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` + AccessRestrictions *AccessRestrictions `protobuf:"bytes,12,opt,name=access_restrictions,json=accessRestrictions,proto3" json:"access_restrictions,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ProxyMapping) Reset() { *x = ProxyMapping{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ProxyMapping) String() string { @@ -635,8 +746,8 @@ func (x *ProxyMapping) String() string { func (*ProxyMapping) ProtoMessage() {} func (x *ProxyMapping) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[6] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[8] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -648,7 +759,7 @@ func (x *ProxyMapping) ProtoReflect() protoreflect.Message { // Deprecated: Use ProxyMapping.ProtoReflect.Descriptor instead. func (*ProxyMapping) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{6} + return file_proxy_service_proto_rawDescGZIP(), []int{8} } func (x *ProxyMapping) GetType() ProxyMappingUpdateType { @@ -728,22 +839,26 @@ func (x *ProxyMapping) GetListenPort() int32 { return 0 } +func (x *ProxyMapping) GetAccessRestrictions() *AccessRestrictions { + if x != nil { + return x.AccessRestrictions + } + return nil +} + // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. 
type SendAccessLogRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Log *AccessLog `protobuf:"bytes,1,opt,name=log,proto3" json:"log,omitempty"` unknownFields protoimpl.UnknownFields - - Log *AccessLog `protobuf:"bytes,1,opt,name=log,proto3" json:"log,omitempty"` + sizeCache protoimpl.SizeCache } func (x *SendAccessLogRequest) Reset() { *x = SendAccessLogRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *SendAccessLogRequest) String() string { @@ -753,8 +868,8 @@ func (x *SendAccessLogRequest) String() string { func (*SendAccessLogRequest) ProtoMessage() {} func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[7] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[9] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -766,7 +881,7 @@ func (x *SendAccessLogRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogRequest.ProtoReflect.Descriptor instead. func (*SendAccessLogRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{7} + return file_proxy_service_proto_rawDescGZIP(), []int{9} } func (x *SendAccessLogRequest) GetLog() *AccessLog { @@ -778,18 +893,16 @@ func (x *SendAccessLogRequest) GetLog() *AccessLog { // SendAccessLogResponse is intentionally empty to allow for future expansion. 
type SendAccessLogResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SendAccessLogResponse) Reset() { *x = SendAccessLogResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *SendAccessLogResponse) String() string { @@ -799,8 +912,8 @@ func (x *SendAccessLogResponse) String() string { func (*SendAccessLogResponse) ProtoMessage() {} func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[8] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[10] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -812,14 +925,11 @@ func (x *SendAccessLogResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendAccessLogResponse.ProtoReflect.Descriptor instead. 
func (*SendAccessLogResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{8} + return file_proxy_service_proto_rawDescGZIP(), []int{10} } type AccessLog struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` Timestamp *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` LogId string `protobuf:"bytes,2,opt,name=log_id,json=logId,proto3" json:"log_id,omitempty"` AccountId string `protobuf:"bytes,3,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` @@ -836,15 +946,15 @@ type AccessLog struct { BytesUpload int64 `protobuf:"varint,14,opt,name=bytes_upload,json=bytesUpload,proto3" json:"bytes_upload,omitempty"` BytesDownload int64 `protobuf:"varint,15,opt,name=bytes_download,json=bytesDownload,proto3" json:"bytes_download,omitempty"` Protocol string `protobuf:"bytes,16,opt,name=protocol,proto3" json:"protocol,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *AccessLog) Reset() { *x = AccessLog{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *AccessLog) String() string { @@ -854,8 +964,8 @@ func (x *AccessLog) String() string { func (*AccessLog) ProtoMessage() {} func (x *AccessLog) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[9] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[11] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -867,7 +977,7 @@ func (x *AccessLog) ProtoReflect() protoreflect.Message { // 
Deprecated: Use AccessLog.ProtoReflect.Descriptor instead. func (*AccessLog) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{9} + return file_proxy_service_proto_rawDescGZIP(), []int{11} } func (x *AccessLog) GetTimestamp() *timestamppb.Timestamp { @@ -983,26 +1093,24 @@ func (x *AccessLog) GetProtocol() string { } type AuthenticateRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - // Types that are assignable to Request: + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + // Types that are valid to be assigned to Request: // // *AuthenticateRequest_Password // *AuthenticateRequest_Pin - Request isAuthenticateRequest_Request `protobuf_oneof:"request"` + // *AuthenticateRequest_HeaderAuth + Request isAuthenticateRequest_Request `protobuf_oneof:"request"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *AuthenticateRequest) Reset() { *x = AuthenticateRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[10] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *AuthenticateRequest) String() string { @@ -1012,8 +1120,8 @@ func (x *AuthenticateRequest) String() string { func (*AuthenticateRequest) ProtoMessage() {} func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[10] - if protoimpl.UnsafeEnabled && x != nil 
{ + mi := &file_proxy_service_proto_msgTypes[12] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1025,7 +1133,7 @@ func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateRequest.ProtoReflect.Descriptor instead. func (*AuthenticateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{10} + return file_proxy_service_proto_rawDescGZIP(), []int{12} } func (x *AuthenticateRequest) GetId() string { @@ -1042,23 +1150,36 @@ func (x *AuthenticateRequest) GetAccountId() string { return "" } -func (m *AuthenticateRequest) GetRequest() isAuthenticateRequest_Request { - if m != nil { - return m.Request +func (x *AuthenticateRequest) GetRequest() isAuthenticateRequest_Request { + if x != nil { + return x.Request } return nil } func (x *AuthenticateRequest) GetPassword() *PasswordRequest { - if x, ok := x.GetRequest().(*AuthenticateRequest_Password); ok { - return x.Password + if x != nil { + if x, ok := x.Request.(*AuthenticateRequest_Password); ok { + return x.Password + } } return nil } func (x *AuthenticateRequest) GetPin() *PinRequest { - if x, ok := x.GetRequest().(*AuthenticateRequest_Pin); ok { - return x.Pin + if x != nil { + if x, ok := x.Request.(*AuthenticateRequest_Pin); ok { + return x.Pin + } + } + return nil +} + +func (x *AuthenticateRequest) GetHeaderAuth() *HeaderAuthRequest { + if x != nil { + if x, ok := x.Request.(*AuthenticateRequest_HeaderAuth); ok { + return x.HeaderAuth + } } return nil } @@ -1075,25 +1196,80 @@ type AuthenticateRequest_Pin struct { Pin *PinRequest `protobuf:"bytes,4,opt,name=pin,proto3,oneof"` } +type AuthenticateRequest_HeaderAuth struct { + HeaderAuth *HeaderAuthRequest `protobuf:"bytes,5,opt,name=header_auth,json=headerAuth,proto3,oneof"` +} + func (*AuthenticateRequest_Password) isAuthenticateRequest_Request() {} func (*AuthenticateRequest_Pin) 
isAuthenticateRequest_Request() {} -type PasswordRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields +func (*AuthenticateRequest_HeaderAuth) isAuthenticateRequest_Request() {} - Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` +type HeaderAuthRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + HeaderValue string `protobuf:"bytes,1,opt,name=header_value,json=headerValue,proto3" json:"header_value,omitempty"` + HeaderName string `protobuf:"bytes,2,opt,name=header_name,json=headerName,proto3" json:"header_name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HeaderAuthRequest) Reset() { + *x = HeaderAuthRequest{} + mi := &file_proxy_service_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HeaderAuthRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HeaderAuthRequest) ProtoMessage() {} + +func (x *HeaderAuthRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HeaderAuthRequest.ProtoReflect.Descriptor instead. 
+func (*HeaderAuthRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{13} +} + +func (x *HeaderAuthRequest) GetHeaderValue() string { + if x != nil { + return x.HeaderValue + } + return "" +} + +func (x *HeaderAuthRequest) GetHeaderName() string { + if x != nil { + return x.HeaderName + } + return "" +} + +type PasswordRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PasswordRequest) Reset() { *x = PasswordRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[11] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *PasswordRequest) String() string { @@ -1103,8 +1279,8 @@ func (x *PasswordRequest) String() string { func (*PasswordRequest) ProtoMessage() {} func (x *PasswordRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[11] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[14] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1116,7 +1292,7 @@ func (x *PasswordRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PasswordRequest.ProtoReflect.Descriptor instead. 
func (*PasswordRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{11} + return file_proxy_service_proto_rawDescGZIP(), []int{14} } func (x *PasswordRequest) GetPassword() string { @@ -1127,20 +1303,17 @@ func (x *PasswordRequest) GetPassword() string { } type PinRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Pin string `protobuf:"bytes,1,opt,name=pin,proto3" json:"pin,omitempty"` unknownFields protoimpl.UnknownFields - - Pin string `protobuf:"bytes,1,opt,name=pin,proto3" json:"pin,omitempty"` + sizeCache protoimpl.SizeCache } func (x *PinRequest) Reset() { *x = PinRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[12] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *PinRequest) String() string { @@ -1150,8 +1323,8 @@ func (x *PinRequest) String() string { func (*PinRequest) ProtoMessage() {} func (x *PinRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[12] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[15] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1163,7 +1336,7 @@ func (x *PinRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PinRequest.ProtoReflect.Descriptor instead. 
func (*PinRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{12} + return file_proxy_service_proto_rawDescGZIP(), []int{15} } func (x *PinRequest) GetPin() string { @@ -1174,21 +1347,18 @@ func (x *PinRequest) GetPin() string { } type AuthenticateResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` unknownFields protoimpl.UnknownFields - - Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` - SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` + sizeCache protoimpl.SizeCache } func (x *AuthenticateResponse) Reset() { *x = AuthenticateResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[13] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *AuthenticateResponse) String() string { @@ -1198,8 +1368,8 @@ func (x *AuthenticateResponse) String() string { func (*AuthenticateResponse) ProtoMessage() {} func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[13] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[16] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1211,7 +1381,7 @@ func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthenticateResponse.ProtoReflect.Descriptor instead. 
func (*AuthenticateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{13} + return file_proxy_service_proto_rawDescGZIP(), []int{16} } func (x *AuthenticateResponse) GetSuccess() bool { @@ -1230,24 +1400,21 @@ func (x *AuthenticateResponse) GetSessionToken() string { // SendStatusUpdateRequest is sent by the proxy to update its status type SendStatusUpdateRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - Status ProxyStatus `protobuf:"varint,3,opt,name=status,proto3,enum=management.ProxyStatus" json:"status,omitempty"` - CertificateIssued bool `protobuf:"varint,4,opt,name=certificate_issued,json=certificateIssued,proto3" json:"certificate_issued,omitempty"` - ErrorMessage *string `protobuf:"bytes,5,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + Status ProxyStatus `protobuf:"varint,3,opt,name=status,proto3,enum=management.ProxyStatus" json:"status,omitempty"` + CertificateIssued bool `protobuf:"varint,4,opt,name=certificate_issued,json=certificateIssued,proto3" json:"certificate_issued,omitempty"` + ErrorMessage *string `protobuf:"bytes,5,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SendStatusUpdateRequest) Reset() { *x = SendStatusUpdateRequest{} - if protoimpl.UnsafeEnabled { - mi := 
&file_proxy_service_proto_msgTypes[14] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *SendStatusUpdateRequest) String() string { @@ -1257,8 +1424,8 @@ func (x *SendStatusUpdateRequest) String() string { func (*SendStatusUpdateRequest) ProtoMessage() {} func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[14] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[17] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1270,7 +1437,7 @@ func (x *SendStatusUpdateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateRequest.ProtoReflect.Descriptor instead. func (*SendStatusUpdateRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{14} + return file_proxy_service_proto_rawDescGZIP(), []int{17} } func (x *SendStatusUpdateRequest) GetServiceId() string { @@ -1310,18 +1477,16 @@ func (x *SendStatusUpdateRequest) GetErrorMessage() string { // SendStatusUpdateResponse is intentionally empty to allow for future expansion type SendStatusUpdateResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[15] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *SendStatusUpdateResponse) String() string 
{ @@ -1331,8 +1496,8 @@ func (x *SendStatusUpdateResponse) String() string { func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[15] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[18] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1344,30 +1509,27 @@ func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateResponse.ProtoReflect.Descriptor instead. func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{15} + return file_proxy_service_proto_rawDescGZIP(), []int{18} } // CreateProxyPeerRequest is sent by the proxy to create a peer connection // The token is a one-time authentication token sent via ProxyMapping type CreateProxyPeerRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - Token string `protobuf:"bytes,3,opt,name=token,proto3" json:"token,omitempty"` - WireguardPublicKey string `protobuf:"bytes,4,opt,name=wireguard_public_key,json=wireguardPublicKey,proto3" json:"wireguard_public_key,omitempty"` - Cluster string `protobuf:"bytes,5,opt,name=cluster,proto3" json:"cluster,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + ServiceId string `protobuf:"bytes,1,opt,name=service_id,json=serviceId,proto3" json:"service_id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + Token string `protobuf:"bytes,3,opt,name=token,proto3" json:"token,omitempty"` 
+ WireguardPublicKey string `protobuf:"bytes,4,opt,name=wireguard_public_key,json=wireguardPublicKey,proto3" json:"wireguard_public_key,omitempty"` + Cluster string `protobuf:"bytes,5,opt,name=cluster,proto3" json:"cluster,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[16] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreateProxyPeerRequest) String() string { @@ -1377,8 +1539,8 @@ func (x *CreateProxyPeerRequest) String() string { func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[16] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[19] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1390,7 +1552,7 @@ func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerRequest.ProtoReflect.Descriptor instead. 
func (*CreateProxyPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{16} + return file_proxy_service_proto_rawDescGZIP(), []int{19} } func (x *CreateProxyPeerRequest) GetServiceId() string { @@ -1430,21 +1592,18 @@ func (x *CreateProxyPeerRequest) GetCluster() string { // CreateProxyPeerResponse contains the result of peer creation type CreateProxyPeerResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMessage *string `protobuf:"bytes,2,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` unknownFields protoimpl.UnknownFields - - Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` - ErrorMessage *string `protobuf:"bytes,2,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` + sizeCache protoimpl.SizeCache } func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[17] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreateProxyPeerResponse) String() string { @@ -1454,8 +1613,8 @@ func (x *CreateProxyPeerResponse) String() string { func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[17] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[20] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1467,7 +1626,7 @@ func (x *CreateProxyPeerResponse) ProtoReflect() 
protoreflect.Message { // Deprecated: Use CreateProxyPeerResponse.ProtoReflect.Descriptor instead. func (*CreateProxyPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{17} + return file_proxy_service_proto_rawDescGZIP(), []int{20} } func (x *CreateProxyPeerResponse) GetSuccess() bool { @@ -1485,22 +1644,19 @@ func (x *CreateProxyPeerResponse) GetErrorMessage() string { } type GetOIDCURLRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` + RedirectUrl string `protobuf:"bytes,3,opt,name=redirect_url,json=redirectUrl,proto3" json:"redirect_url,omitempty"` unknownFields protoimpl.UnknownFields - - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - AccountId string `protobuf:"bytes,2,opt,name=account_id,json=accountId,proto3" json:"account_id,omitempty"` - RedirectUrl string `protobuf:"bytes,3,opt,name=redirect_url,json=redirectUrl,proto3" json:"redirect_url,omitempty"` + sizeCache protoimpl.SizeCache } func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[18] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[21] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetOIDCURLRequest) String() string { @@ -1510,8 +1666,8 @@ func (x *GetOIDCURLRequest) String() string { func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[18] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[21] + if x != nil { ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1523,7 +1679,7 @@ func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLRequest.ProtoReflect.Descriptor instead. func (*GetOIDCURLRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{18} + return file_proxy_service_proto_rawDescGZIP(), []int{21} } func (x *GetOIDCURLRequest) GetId() string { @@ -1548,20 +1704,17 @@ func (x *GetOIDCURLRequest) GetRedirectUrl() string { } type GetOIDCURLResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` unknownFields protoimpl.UnknownFields - - Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` + sizeCache protoimpl.SizeCache } func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[19] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetOIDCURLResponse) String() string { @@ -1571,8 +1724,8 @@ func (x *GetOIDCURLResponse) String() string { func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[19] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[22] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1584,7 +1737,7 @@ func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLResponse.ProtoReflect.Descriptor instead. 
func (*GetOIDCURLResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{19} + return file_proxy_service_proto_rawDescGZIP(), []int{22} } func (x *GetOIDCURLResponse) GetUrl() string { @@ -1595,21 +1748,18 @@ func (x *GetOIDCURLResponse) GetUrl() string { } type ValidateSessionRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` + SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` unknownFields protoimpl.UnknownFields - - Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` - SessionToken string `protobuf:"bytes,2,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[20] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ValidateSessionRequest) String() string { @@ -1619,8 +1769,8 @@ func (x *ValidateSessionRequest) String() string { func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[20] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[23] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1632,7 +1782,7 @@ func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. 
func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{20} + return file_proxy_service_proto_rawDescGZIP(), []int{23} } func (x *ValidateSessionRequest) GetDomain() string { @@ -1650,23 +1800,20 @@ func (x *ValidateSessionRequest) GetSessionToken() string { } type ValidateSessionResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Valid bool `protobuf:"varint,1,opt,name=valid,proto3" json:"valid,omitempty"` + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + UserEmail string `protobuf:"bytes,3,opt,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` + DeniedReason string `protobuf:"bytes,4,opt,name=denied_reason,json=deniedReason,proto3" json:"denied_reason,omitempty"` unknownFields protoimpl.UnknownFields - - Valid bool `protobuf:"varint,1,opt,name=valid,proto3" json:"valid,omitempty"` - UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` - UserEmail string `protobuf:"bytes,3,opt,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` - DeniedReason string `protobuf:"bytes,4,opt,name=denied_reason,json=deniedReason,proto3" json:"denied_reason,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[21] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_proxy_service_proto_msgTypes[24] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ValidateSessionResponse) String() string { @@ -1676,8 +1823,8 @@ func (x *ValidateSessionResponse) String() string { func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { - mi := 
&file_proxy_service_proto_msgTypes[21] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_proxy_service_proto_msgTypes[24] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1689,7 +1836,7 @@ func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{21} + return file_proxy_service_proto_rawDescGZIP(), []int{24} } func (x *ValidateSessionResponse) GetValid() bool { @@ -1722,317 +1869,193 @@ func (x *ValidateSessionResponse) GetDeniedReason() string { var File_proxy_service_proto protoreflect.FileDescriptor -var file_proxy_service_proto_rawDesc = []byte{ - 0x0a, 0x13, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x22, 0x66, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, - 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x15, 0x73, 0x75, 0x70, 0x70, 0x6f, - 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x13, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, - 0x74, 0x73, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x6f, 0x72, 0x74, 0x73, 0x88, 0x01, 0x01, - 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 
0x74, 0x73, 0x5f, 0x63, 0x75, - 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x22, 0xe6, 0x01, 0x0a, 0x17, 0x47, - 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, - 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, - 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, - 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, - 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, - 0x70, 
0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, - 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, - 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xce, 0x03, 0x0a, 0x11, 0x50, 0x61, 0x74, - 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x26, - 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, 0x73, - 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, 0x61, - 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, - 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, 0x70, - 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, 0x75, - 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, - 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 
0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, - 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, 0x6f, - 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, 0x65, - 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, - 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, 0x65, - 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, - 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, - 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x72, 0x0a, 0x0b, 0x50, 0x61, 0x74, - 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, - 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x61, - 0x72, 0x67, 0x65, 0x74, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xaa, 0x01, - 0x0a, 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 
0x69, 0x6f, 0x6e, - 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x6b, 0x65, 0x79, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, - 0x79, 0x12, 0x35, 0x0a, 0x17, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, - 0x5f, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x14, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x41, 0x67, - 0x65, 0x53, 0x65, 0x63, 0x6f, 0x6e, 0x64, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, - 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, - 0x77, 0x6f, 0x72, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x6f, 0x69, 0x64, 0x63, 0x22, 0x95, 0x03, 0x0a, 0x0c, 0x50, - 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, 0x74, - 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, - 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, - 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, - 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, - 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2b, 0x0a, 0x04, 0x70, 0x61, - 0x74, 0x68, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 
0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x5f, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x75, 0x74, - 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x04, 0x61, 0x75, 0x74, 0x68, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x52, 0x04, 0x61, 0x75, 0x74, 0x68, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x61, 0x73, 0x73, 0x5f, 0x68, - 0x6f, 0x73, 0x74, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x0e, 0x70, 0x61, 0x73, 0x73, 0x48, 0x6f, 0x73, 0x74, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x12, 0x2b, 0x0a, 0x11, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x5f, 0x72, 0x65, 0x64, 0x69, - 0x72, 0x65, 0x63, 0x74, 0x73, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x65, 0x77, - 0x72, 0x69, 0x74, 0x65, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x73, 0x12, 0x12, 0x0a, - 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6d, 0x6f, 0x64, - 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, - 0x18, 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, - 0x72, 0x74, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, - 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, - 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, - 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 
0x6e, 0x73, 0x65, 0x22, 0x86, 0x04, 0x0a, - 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, - 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, - 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, - 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, - 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, - 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 
0x0e, - 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, - 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, - 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, - 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, - 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, - 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, - 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, - 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0xb6, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, - 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, - 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, - 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x39, 0x0a, 0x08, - 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x73, 0x73, - 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x08, 0x70, - 0x61, 0x73, 0x73, 0x77, 0x6f, 
0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, - 0x70, 0x69, 0x6e, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x2d, - 0x0a, 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, - 0x0a, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, - 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, - 0x14, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, - 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xf3, 0x01, 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, - 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, - 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 
0x01, 0x28, 0x0e, 0x32, 0x17, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, - 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x2d, 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, - 0x73, 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, - 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, - 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, - 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, - 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, - 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, - 0x72, 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, - 0x62, 
0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, - 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, - 0x72, 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, - 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, - 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, - 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, - 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, - 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, - 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, - 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, - 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 
0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, - 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, - 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, - 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, - 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, - 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, - 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, - 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, - 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, - 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, - 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, - 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, - 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, - 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, - 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, - 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 
0x49, 0x54, 0x45, - 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, - 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, - 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, - 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, - 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, - 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, - 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, - 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, - 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, - 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, - 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, - 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, - 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, - 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, - 0x32, 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, - 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, - 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, - 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, - 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, - 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, - 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 
0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, - 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, - 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, - 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, - 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, - 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, -} +const file_proxy_service_proto_rawDesc = "" + + "\n" + + "\x13proxy_service.proto\x12\n" + + "management\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"f\n" + + "\x11ProxyCapabilities\x127\n" + + "\x15supports_custom_ports\x18\x01 \x01(\bH\x00R\x13supportsCustomPorts\x88\x01\x01B\x18\n" + + "\x16_supports_custom_ports\"\xe6\x01\n" + + "\x17GetMappingUpdateRequest\x12\x19\n" + + "\bproxy_id\x18\x01 \x01(\tR\aproxyId\x12\x18\n" + + 
"\aversion\x18\x02 \x01(\tR\aversion\x129\n" + + "\n" + + "started_at\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\tstartedAt\x12\x18\n" + + "\aaddress\x18\x04 \x01(\tR\aaddress\x12A\n" + + "\fcapabilities\x18\x05 \x01(\v2\x1d.management.ProxyCapabilitiesR\fcapabilities\"\x82\x01\n" + + "\x18GetMappingUpdateResponse\x122\n" + + "\amapping\x18\x01 \x03(\v2\x18.management.ProxyMappingR\amapping\x122\n" + + "\x15initial_sync_complete\x18\x02 \x01(\bR\x13initialSyncComplete\"\xce\x03\n" + + "\x11PathTargetOptions\x12&\n" + + "\x0fskip_tls_verify\x18\x01 \x01(\bR\rskipTlsVerify\x12B\n" + + "\x0frequest_timeout\x18\x02 \x01(\v2\x19.google.protobuf.DurationR\x0erequestTimeout\x12>\n" + + "\fpath_rewrite\x18\x03 \x01(\x0e2\x1b.management.PathRewriteModeR\vpathRewrite\x12W\n" + + "\x0ecustom_headers\x18\x04 \x03(\v20.management.PathTargetOptions.CustomHeadersEntryR\rcustomHeaders\x12%\n" + + "\x0eproxy_protocol\x18\x05 \x01(\bR\rproxyProtocol\x12K\n" + + "\x14session_idle_timeout\x18\x06 \x01(\v2\x19.google.protobuf.DurationR\x12sessionIdleTimeout\x1a@\n" + + "\x12CustomHeadersEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"r\n" + + "\vPathMapping\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\x12\x16\n" + + "\x06target\x18\x02 \x01(\tR\x06target\x127\n" + + "\aoptions\x18\x03 \x01(\v2\x1d.management.PathTargetOptionsR\aoptions\"G\n" + + "\n" + + "HeaderAuth\x12\x16\n" + + "\x06header\x18\x01 \x01(\tR\x06header\x12!\n" + + "\fhashed_value\x18\x02 \x01(\tR\vhashedValue\"\xe5\x01\n" + + "\x0eAuthentication\x12\x1f\n" + + "\vsession_key\x18\x01 \x01(\tR\n" + + "sessionKey\x125\n" + + "\x17max_session_age_seconds\x18\x02 \x01(\x03R\x14maxSessionAgeSeconds\x12\x1a\n" + + "\bpassword\x18\x03 \x01(\bR\bpassword\x12\x10\n" + + "\x03pin\x18\x04 \x01(\bR\x03pin\x12\x12\n" + + "\x04oidc\x18\x05 \x01(\bR\x04oidc\x129\n" + + "\fheader_auths\x18\x06 \x03(\v2\x16.management.HeaderAuthR\vheaderAuths\"\xb8\x01\n" + + 
"\x12AccessRestrictions\x12#\n" + + "\rallowed_cidrs\x18\x01 \x03(\tR\fallowedCidrs\x12#\n" + + "\rblocked_cidrs\x18\x02 \x03(\tR\fblockedCidrs\x12+\n" + + "\x11allowed_countries\x18\x03 \x03(\tR\x10allowedCountries\x12+\n" + + "\x11blocked_countries\x18\x04 \x03(\tR\x10blockedCountries\"\xe6\x03\n" + + "\fProxyMapping\x126\n" + + "\x04type\x18\x01 \x01(\x0e2\".management.ProxyMappingUpdateTypeR\x04type\x12\x0e\n" + + "\x02id\x18\x02 \x01(\tR\x02id\x12\x1d\n" + + "\n" + + "account_id\x18\x03 \x01(\tR\taccountId\x12\x16\n" + + "\x06domain\x18\x04 \x01(\tR\x06domain\x12+\n" + + "\x04path\x18\x05 \x03(\v2\x17.management.PathMappingR\x04path\x12\x1d\n" + + "\n" + + "auth_token\x18\x06 \x01(\tR\tauthToken\x12.\n" + + "\x04auth\x18\a \x01(\v2\x1a.management.AuthenticationR\x04auth\x12(\n" + + "\x10pass_host_header\x18\b \x01(\bR\x0epassHostHeader\x12+\n" + + "\x11rewrite_redirects\x18\t \x01(\bR\x10rewriteRedirects\x12\x12\n" + + "\x04mode\x18\n" + + " \x01(\tR\x04mode\x12\x1f\n" + + "\vlisten_port\x18\v \x01(\x05R\n" + + "listenPort\x12O\n" + + "\x13access_restrictions\x18\f \x01(\v2\x1e.management.AccessRestrictionsR\x12accessRestrictions\"?\n" + + "\x14SendAccessLogRequest\x12'\n" + + "\x03log\x18\x01 \x01(\v2\x15.management.AccessLogR\x03log\"\x17\n" + + "\x15SendAccessLogResponse\"\x86\x04\n" + + "\tAccessLog\x128\n" + + "\ttimestamp\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x15\n" + + "\x06log_id\x18\x02 \x01(\tR\x05logId\x12\x1d\n" + + "\n" + + "account_id\x18\x03 \x01(\tR\taccountId\x12\x1d\n" + + "\n" + + "service_id\x18\x04 \x01(\tR\tserviceId\x12\x12\n" + + "\x04host\x18\x05 \x01(\tR\x04host\x12\x12\n" + + "\x04path\x18\x06 \x01(\tR\x04path\x12\x1f\n" + + "\vduration_ms\x18\a \x01(\x03R\n" + + "durationMs\x12\x16\n" + + "\x06method\x18\b \x01(\tR\x06method\x12#\n" + + "\rresponse_code\x18\t \x01(\x05R\fresponseCode\x12\x1b\n" + + "\tsource_ip\x18\n" + + " \x01(\tR\bsourceIp\x12%\n" + + "\x0eauth_mechanism\x18\v 
\x01(\tR\rauthMechanism\x12\x17\n" + + "\auser_id\x18\f \x01(\tR\x06userId\x12!\n" + + "\fauth_success\x18\r \x01(\bR\vauthSuccess\x12!\n" + + "\fbytes_upload\x18\x0e \x01(\x03R\vbytesUpload\x12%\n" + + "\x0ebytes_download\x18\x0f \x01(\x03R\rbytesDownload\x12\x1a\n" + + "\bprotocol\x18\x10 \x01(\tR\bprotocol\"\xf8\x01\n" + + "\x13AuthenticateRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n" + + "\n" + + "account_id\x18\x02 \x01(\tR\taccountId\x129\n" + + "\bpassword\x18\x03 \x01(\v2\x1b.management.PasswordRequestH\x00R\bpassword\x12*\n" + + "\x03pin\x18\x04 \x01(\v2\x16.management.PinRequestH\x00R\x03pin\x12@\n" + + "\vheader_auth\x18\x05 \x01(\v2\x1d.management.HeaderAuthRequestH\x00R\n" + + "headerAuthB\t\n" + + "\arequest\"W\n" + + "\x11HeaderAuthRequest\x12!\n" + + "\fheader_value\x18\x01 \x01(\tR\vheaderValue\x12\x1f\n" + + "\vheader_name\x18\x02 \x01(\tR\n" + + "headerName\"-\n" + + "\x0fPasswordRequest\x12\x1a\n" + + "\bpassword\x18\x01 \x01(\tR\bpassword\"\x1e\n" + + "\n" + + "PinRequest\x12\x10\n" + + "\x03pin\x18\x01 \x01(\tR\x03pin\"U\n" + + "\x14AuthenticateResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12#\n" + + "\rsession_token\x18\x02 \x01(\tR\fsessionToken\"\xf3\x01\n" + + "\x17SendStatusUpdateRequest\x12\x1d\n" + + "\n" + + "service_id\x18\x01 \x01(\tR\tserviceId\x12\x1d\n" + + "\n" + + "account_id\x18\x02 \x01(\tR\taccountId\x12/\n" + + "\x06status\x18\x03 \x01(\x0e2\x17.management.ProxyStatusR\x06status\x12-\n" + + "\x12certificate_issued\x18\x04 \x01(\bR\x11certificateIssued\x12(\n" + + "\rerror_message\x18\x05 \x01(\tH\x00R\ferrorMessage\x88\x01\x01B\x10\n" + + "\x0e_error_message\"\x1a\n" + + "\x18SendStatusUpdateResponse\"\xb8\x01\n" + + "\x16CreateProxyPeerRequest\x12\x1d\n" + + "\n" + + "service_id\x18\x01 \x01(\tR\tserviceId\x12\x1d\n" + + "\n" + + "account_id\x18\x02 \x01(\tR\taccountId\x12\x14\n" + + "\x05token\x18\x03 \x01(\tR\x05token\x120\n" + + "\x14wireguard_public_key\x18\x04 
\x01(\tR\x12wireguardPublicKey\x12\x18\n" + + "\acluster\x18\x05 \x01(\tR\acluster\"o\n" + + "\x17CreateProxyPeerResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12(\n" + + "\rerror_message\x18\x02 \x01(\tH\x00R\ferrorMessage\x88\x01\x01B\x10\n" + + "\x0e_error_message\"e\n" + + "\x11GetOIDCURLRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n" + + "\n" + + "account_id\x18\x02 \x01(\tR\taccountId\x12!\n" + + "\fredirect_url\x18\x03 \x01(\tR\vredirectUrl\"&\n" + + "\x12GetOIDCURLResponse\x12\x10\n" + + "\x03url\x18\x01 \x01(\tR\x03url\"U\n" + + "\x16ValidateSessionRequest\x12\x16\n" + + "\x06domain\x18\x01 \x01(\tR\x06domain\x12#\n" + + "\rsession_token\x18\x02 \x01(\tR\fsessionToken\"\x8c\x01\n" + + "\x17ValidateSessionResponse\x12\x14\n" + + "\x05valid\x18\x01 \x01(\bR\x05valid\x12\x17\n" + + "\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1d\n" + + "\n" + + "user_email\x18\x03 \x01(\tR\tuserEmail\x12#\n" + + "\rdenied_reason\x18\x04 \x01(\tR\fdeniedReason*d\n" + + "\x16ProxyMappingUpdateType\x12\x17\n" + + "\x13UPDATE_TYPE_CREATED\x10\x00\x12\x18\n" + + "\x14UPDATE_TYPE_MODIFIED\x10\x01\x12\x17\n" + + "\x13UPDATE_TYPE_REMOVED\x10\x02*F\n" + + "\x0fPathRewriteMode\x12\x18\n" + + "\x14PATH_REWRITE_DEFAULT\x10\x00\x12\x19\n" + + "\x15PATH_REWRITE_PRESERVE\x10\x01*\xc8\x01\n" + + "\vProxyStatus\x12\x18\n" + + "\x14PROXY_STATUS_PENDING\x10\x00\x12\x17\n" + + "\x13PROXY_STATUS_ACTIVE\x10\x01\x12#\n" + + "\x1fPROXY_STATUS_TUNNEL_NOT_CREATED\x10\x02\x12$\n" + + " PROXY_STATUS_CERTIFICATE_PENDING\x10\x03\x12#\n" + + "\x1fPROXY_STATUS_CERTIFICATE_FAILED\x10\x04\x12\x16\n" + + "\x12PROXY_STATUS_ERROR\x10\x052\xfc\x04\n" + + "\fProxyService\x12_\n" + + "\x10GetMappingUpdate\x12#.management.GetMappingUpdateRequest\x1a$.management.GetMappingUpdateResponse0\x01\x12T\n" + + "\rSendAccessLog\x12 .management.SendAccessLogRequest\x1a!.management.SendAccessLogResponse\x12Q\n" + + "\fAuthenticate\x12\x1f.management.AuthenticateRequest\x1a 
.management.AuthenticateResponse\x12]\n" + + "\x10SendStatusUpdate\x12#.management.SendStatusUpdateRequest\x1a$.management.SendStatusUpdateResponse\x12Z\n" + + "\x0fCreateProxyPeer\x12\".management.CreateProxyPeerRequest\x1a#.management.CreateProxyPeerResponse\x12K\n" + + "\n" + + "GetOIDCURL\x12\x1d.management.GetOIDCURLRequest\x1a\x1e.management.GetOIDCURLResponse\x12Z\n" + + "\x0fValidateSession\x12\".management.ValidateSessionRequest\x1a#.management.ValidateSessionResponseB\bZ\x06/protob\x06proto3" var ( file_proxy_service_proto_rawDescOnce sync.Once - file_proxy_service_proto_rawDescData = file_proxy_service_proto_rawDesc + file_proxy_service_proto_rawDescData []byte ) func file_proxy_service_proto_rawDescGZIP() []byte { file_proxy_service_proto_rawDescOnce.Do(func() { - file_proxy_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_proxy_service_proto_rawDescData) + file_proxy_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proxy_service_proto_rawDesc), len(file_proxy_service_proto_rawDesc))) }) return file_proxy_service_proto_rawDescData } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 23) -var file_proxy_service_proto_goTypes = []interface{}{ +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 26) +var file_proxy_service_proto_goTypes = []any{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (PathRewriteMode)(0), // 1: management.PathRewriteMode (ProxyStatus)(0), // 2: management.ProxyStatus @@ -2041,63 +2064,69 @@ var file_proxy_service_proto_goTypes = []interface{}{ (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse (*PathTargetOptions)(nil), // 6: management.PathTargetOptions (*PathMapping)(nil), // 7: management.PathMapping - (*Authentication)(nil), // 8: management.Authentication - (*ProxyMapping)(nil), // 9: management.ProxyMapping - 
(*SendAccessLogRequest)(nil), // 10: management.SendAccessLogRequest - (*SendAccessLogResponse)(nil), // 11: management.SendAccessLogResponse - (*AccessLog)(nil), // 12: management.AccessLog - (*AuthenticateRequest)(nil), // 13: management.AuthenticateRequest - (*PasswordRequest)(nil), // 14: management.PasswordRequest - (*PinRequest)(nil), // 15: management.PinRequest - (*AuthenticateResponse)(nil), // 16: management.AuthenticateResponse - (*SendStatusUpdateRequest)(nil), // 17: management.SendStatusUpdateRequest - (*SendStatusUpdateResponse)(nil), // 18: management.SendStatusUpdateResponse - (*CreateProxyPeerRequest)(nil), // 19: management.CreateProxyPeerRequest - (*CreateProxyPeerResponse)(nil), // 20: management.CreateProxyPeerResponse - (*GetOIDCURLRequest)(nil), // 21: management.GetOIDCURLRequest - (*GetOIDCURLResponse)(nil), // 22: management.GetOIDCURLResponse - (*ValidateSessionRequest)(nil), // 23: management.ValidateSessionRequest - (*ValidateSessionResponse)(nil), // 24: management.ValidateSessionResponse - nil, // 25: management.PathTargetOptions.CustomHeadersEntry - (*timestamppb.Timestamp)(nil), // 26: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 27: google.protobuf.Duration + (*HeaderAuth)(nil), // 8: management.HeaderAuth + (*Authentication)(nil), // 9: management.Authentication + (*AccessRestrictions)(nil), // 10: management.AccessRestrictions + (*ProxyMapping)(nil), // 11: management.ProxyMapping + (*SendAccessLogRequest)(nil), // 12: management.SendAccessLogRequest + (*SendAccessLogResponse)(nil), // 13: management.SendAccessLogResponse + (*AccessLog)(nil), // 14: management.AccessLog + (*AuthenticateRequest)(nil), // 15: management.AuthenticateRequest + (*HeaderAuthRequest)(nil), // 16: management.HeaderAuthRequest + (*PasswordRequest)(nil), // 17: management.PasswordRequest + (*PinRequest)(nil), // 18: management.PinRequest + (*AuthenticateResponse)(nil), // 19: management.AuthenticateResponse + 
(*SendStatusUpdateRequest)(nil), // 20: management.SendStatusUpdateRequest + (*SendStatusUpdateResponse)(nil), // 21: management.SendStatusUpdateResponse + (*CreateProxyPeerRequest)(nil), // 22: management.CreateProxyPeerRequest + (*CreateProxyPeerResponse)(nil), // 23: management.CreateProxyPeerResponse + (*GetOIDCURLRequest)(nil), // 24: management.GetOIDCURLRequest + (*GetOIDCURLResponse)(nil), // 25: management.GetOIDCURLResponse + (*ValidateSessionRequest)(nil), // 26: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 27: management.ValidateSessionResponse + nil, // 28: management.PathTargetOptions.CustomHeadersEntry + (*timestamppb.Timestamp)(nil), // 29: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 30: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 26, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 29, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities - 9, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 27, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 11, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping + 30, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 25, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 27, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 28, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 30, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> 
google.protobuf.Duration 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions - 0, // 8: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType - 7, // 9: management.ProxyMapping.path:type_name -> management.PathMapping - 8, // 10: management.ProxyMapping.auth:type_name -> management.Authentication - 12, // 11: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 26, // 12: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 14, // 13: management.AuthenticateRequest.password:type_name -> management.PasswordRequest - 15, // 14: management.AuthenticateRequest.pin:type_name -> management.PinRequest - 2, // 15: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 4, // 16: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 10, // 17: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 13, // 18: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 17, // 19: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 19, // 20: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 21, // 21: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 23, // 22: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 5, // 23: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 11, // 24: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 16, // 25: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 18, // 26: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 20, // 27: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 22, // 28: 
management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 24, // 29: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 23, // [23:30] is the sub-list for method output_type - 16, // [16:23] is the sub-list for method input_type - 16, // [16:16] is the sub-list for extension type_name - 16, // [16:16] is the sub-list for extension extendee - 0, // [0:16] is the sub-list for field type_name + 8, // 8: management.Authentication.header_auths:type_name -> management.HeaderAuth + 0, // 9: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType + 7, // 10: management.ProxyMapping.path:type_name -> management.PathMapping + 9, // 11: management.ProxyMapping.auth:type_name -> management.Authentication + 10, // 12: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions + 14, // 13: management.SendAccessLogRequest.log:type_name -> management.AccessLog + 29, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 17, // 15: management.AuthenticateRequest.password:type_name -> management.PasswordRequest + 18, // 16: management.AuthenticateRequest.pin:type_name -> management.PinRequest + 16, // 17: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest + 2, // 18: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus + 4, // 19: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 12, // 20: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 15, // 21: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 20, // 22: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 22, // 23: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 24, // 24: management.ProxyService.GetOIDCURL:input_type -> 
management.GetOIDCURLRequest + 26, // 25: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 5, // 26: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 13, // 27: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 19, // 28: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 21, // 29: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 23, // 30: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 25, // 31: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 27, // 32: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 26, // [26:33] is the sub-list for method output_type + 19, // [19:26] is the sub-list for method input_type + 19, // [19:19] is the sub-list for extension type_name + 19, // [19:19] is the sub-list for extension extendee + 0, // [0:19] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -2105,286 +2134,21 @@ func file_proxy_service_proto_init() { if File_proxy_service_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_proxy_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyCapabilities); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetMappingUpdateResponse); i { - case 0: - return &v.state - 
case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathTargetOptions); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PathMapping); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Authentication); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProxyMapping); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendAccessLogResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AccessLog); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} 
{ - switch v := v.(*AuthenticateRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PasswordRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PinRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*AuthenticateResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - 
default: - return nil - } - } - file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[10].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[0].OneofWrappers = []any{} + file_proxy_service_proto_msgTypes[12].OneofWrappers = []any{ (*AuthenticateRequest_Password)(nil), (*AuthenticateRequest_Pin)(nil), + (*AuthenticateRequest_HeaderAuth)(nil), } - file_proxy_service_proto_msgTypes[14].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[17].OneofWrappers = []any{} + file_proxy_service_proto_msgTypes[20].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_proxy_service_proto_rawDesc, + RawDescriptor: 
unsafe.Slice(unsafe.StringData(file_proxy_service_proto_rawDesc), len(file_proxy_service_proto_rawDesc)), NumEnums: 3, - NumMessages: 23, + NumMessages: 26, NumExtensions: 0, NumServices: 1, }, @@ -2394,7 +2158,6 @@ func file_proxy_service_proto_init() { MessageInfos: file_proxy_service_proto_msgTypes, }.Build() File_proxy_service_proto = out.File - file_proxy_service_proto_rawDesc = nil file_proxy_service_proto_goTypes = nil file_proxy_service_proto_depIdxs = nil } diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index 457d12e85..2d7bed548 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -80,12 +80,27 @@ message PathMapping { PathTargetOptions options = 3; } +message HeaderAuth { + // Header name to check, e.g. "Authorization", "X-API-Key". + string header = 1; + // argon2id hash of the expected full header value. + string hashed_value = 2; +} + message Authentication { string session_key = 1; int64 max_session_age_seconds = 2; bool password = 3; bool pin = 4; bool oidc = 5; + repeated HeaderAuth header_auths = 6; +} + +message AccessRestrictions { + repeated string allowed_cidrs = 1; + repeated string blocked_cidrs = 2; + repeated string allowed_countries = 3; + repeated string blocked_countries = 4; } message ProxyMapping { @@ -106,6 +121,7 @@ message ProxyMapping { string mode = 10; // For L4/TLS: the port the proxy listens on. int32 listen_port = 11; + AccessRestrictions access_restrictions = 12; } // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. 
@@ -141,9 +157,15 @@ message AuthenticateRequest { oneof request { PasswordRequest password = 3; PinRequest pin = 4; + HeaderAuthRequest header_auth = 5; } } +message HeaderAuthRequest { + string header_value = 1; + string header_name = 2; +} + message PasswordRequest { string password = 1; } diff --git a/shared/relay/client/early_msg_buffer.go b/shared/relay/client/early_msg_buffer.go index 3ead94de1..52ff4d42e 100644 --- a/shared/relay/client/early_msg_buffer.go +++ b/shared/relay/client/early_msg_buffer.go @@ -65,8 +65,8 @@ func (b *earlyMsgBuffer) put(peerID messages.PeerID, msg Msg) bool { } entry := earlyMsg{ - peerID: peerID, - msg: msg, + peerID: peerID, + msg: msg, createdAt: time.Now(), } elem := b.order.PushBack(entry) From 80a8816b1dbb46d9dd3525f54abdb948ce04da66 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:00:23 +0100 Subject: [PATCH 23/27] [misc] Add image build after merge to main (#5605) --- .github/workflows/release.yml | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7ac5103d9..1a4676625 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -211,18 +211,36 @@ jobs: - name: Clean up GPG key if: always() run: rm -f /tmp/gpg-rpm-signing-key.asc - - name: Tag and push PR images (amd64 only) - if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository + - name: Tag and push images (amd64 only) + if: | + (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || + (github.event_name == 'push' && github.ref == 'refs/heads/main') run: | - PR_TAG="pr-${{ github.event.pull_request.number }}" + resolve_tags() { + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + echo "pr-${{ github.event.pull_request.number }}" + else + 
echo "main sha-$(git rev-parse --short HEAD)" + fi + } + + tag_and_push() { + local src="$1" img_name tag dst + img_name="${src%%:*}" + for tag in $(resolve_tags); do + dst="${img_name}:${tag}" + echo "Tagging ${src} -> ${dst}" + docker tag "$src" "$dst" + docker push "$dst" + done + } + + export -f tag_and_push resolve_tags + echo '${{ steps.goreleaser.outputs.artifacts }}' | \ jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \ grep '^ghcr.io/' | while read -r SRC; do - IMG_NAME="${SRC%%:*}" - DST="${IMG_NAME}:${PR_TAG}" - echo "Tagging ${SRC} -> ${DST}" - docker tag "$SRC" "$DST" - docker push "$DST" + tag_and_push "$SRC" done - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 From dff06d089874bc58b0bcdea1f54454083406f69a Mon Sep 17 00:00:00 2001 From: n0pashkov Date: Tue, 17 Mar 2026 07:33:13 +0300 Subject: [PATCH 24/27] [misc] Add netbird-tui to community projects (#5568) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bca81c20b..dc84af2fd 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how ### Community projects - [NetBird installer script](https://github.com/physk/netbird-installer) - [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/) +- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings **Note**: The `main` branch may be in an *unstable or even broken state* during development. For stable versions, see [releases](https://github.com/netbirdio/netbird/releases). 
From 59f5b34280c2adef6b836239115ffd9d7d293dbb Mon Sep 17 00:00:00 2001 From: tham-le <45093611+tham-le@users.noreply.github.com> Date: Tue, 17 Mar 2026 06:03:10 +0100 Subject: [PATCH 25/27] [client] add MTU option to embed.Options (#5550) Expose MTU configuration in the embed package so embedded clients can set the WireGuard tunnel MTU without the config file workaround. This is needed for protocols like QUIC that require larger datagrams than the default MTU of 1280. Validates MTU range via iface.ValidateMTU() at construction time to prevent invalid values from being persisted to config. Closes #5549 --- client/embed/embed.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/client/embed/embed.go b/client/embed/embed.go index 21043cf96..70013989a 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -14,6 +14,7 @@ import ( "github.com/sirupsen/logrus" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" @@ -81,6 +82,12 @@ type Options struct { BlockInbound bool // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. WireguardPort *int + // MTU is the MTU for the WireGuard interface. + // Valid values are in the range 576..8192 bytes. + // If non-nil, this value overrides any value stored in the config file. + // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. + // Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams. 
+ MTU *uint16 } // validateCredentials checks that exactly one credential type is provided @@ -112,6 +119,12 @@ func New(opts Options) (*Client, error) { return nil, err } + if opts.MTU != nil { + if err := iface.ValidateMTU(*opts.MTU); err != nil { + return nil, fmt.Errorf("invalid MTU: %w", err) + } + } + if opts.LogOutput != nil { logrus.SetOutput(opts.LogOutput) } @@ -151,6 +164,7 @@ func New(opts Options) (*Client, error) { DisableClientRoutes: &opts.DisableClientRoutes, BlockInbound: &opts.BlockInbound, WireguardPort: opts.WireguardPort, + MTU: opts.MTU, } if opts.ConfigPath != "" { config, err = profilemanager.UpdateOrCreateConfig(input) From 4e149c9222aaefccac1fe7cbe841e946c3b76ea4 Mon Sep 17 00:00:00 2001 From: Wesley Gimenes Date: Tue, 17 Mar 2026 02:09:12 -0300 Subject: [PATCH 26/27] [client] update gvisor to build with Go 1.26.x (#5447) Building the client with Go 1.26.x fails with errors: ``` [...] /builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go126.go:22:2: WaitReasonSelect redeclared in this block /builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go125.go:22:2: other declaration of WaitReasonSelect /builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go126.go:23:2: WaitReasonChanReceive redeclared in this block /builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go125.go:23:2: other declaration of WaitReasonChanReceive /builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go126.go:24:2: WaitReasonSemacquire redeclared in this block /builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go125.go:24:2: other declaration of WaitReasonSemacquire [...] 
``` Fixes: https://github.com/netbirdio/netbird/issues/5290 ("Does not build with Go 1.26rc3") Signed-off-by: Wesley Gimenes --- go.mod | 6 ++---- go.sum | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 4bcdbdc78..f8b27aca0 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/netbirdio/netbird -go 1.25 - -toolchain go1.25.5 +go 1.25.5 require ( cunicu.li/go-rosenpass v0.4.0 @@ -125,7 +123,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 - gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c + gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 ) require ( diff --git a/go.sum b/go.sum index 1bd9396bb..d4b42eab4 100644 --- a/go.sum +++ b/go.sum @@ -852,5 +852,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA= -gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ= +gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA= +gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= From a590c38d8b1f4429535f49ff5cac9655a29483a7 Mon Sep 17 00:00:00 2001 From: eason <85663565+mango766@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:27:47 +0800 Subject: [PATCH 27/27] [client] Fix IPv6 address formatting in DNS address construction (#5603) Replace fmt.Sprintf("%s:%d", ip, port) with net.JoinHostPort() to properly handle IPv6 addresses that need bracket wrapping (e.g., [2606:4700:4700::1111]:53 instead of 2606:4700:4700::1111:53). 
Without this fix, configuring IPv6 nameservers causes "too many colons in address" errors because Go's net.Dial cannot parse the malformed address string. Fixes #5601 Related to #4074 Co-authored-by: easonysliu --- client/internal/dns/service_listener.go | 5 +++-- client/internal/routemanager/client/client.go | 4 +++- client/internal/routemanager/dnsinterceptor/handler.go | 4 +++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 806559444..f7ddfd40f 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "runtime" + "strconv" "sync" "time" @@ -69,7 +70,7 @@ func (s *serviceViaListener) Listen() error { return fmt.Errorf("eval listen address: %w", err) } s.listenIP = s.listenIP.Unmap() - s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) + s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort))) log.Debugf("starting dns on %s", s.server.Addr) go func() { s.setListenerStatus(true) @@ -186,7 +187,7 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) { } func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool { - addrString := fmt.Sprintf("%s:%d", ip, port) + addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port)) udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) probeListener, err := net.ListenUDP("udp", udpAddr) if err != nil { diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index bad616271..e6ef8b876 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -3,7 +3,9 @@ package client import ( "context" "fmt" + "net" "reflect" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -564,7 +566,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler { 
return dnsinterceptor.New(params) case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(params.WgInterface) - dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()) + dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) return dynamic.NewRoute(params, dnsAddr) default: return static.NewRoute(params) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 4bf0d5476..64f2a8789 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "net" "net/netip" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -249,7 +251,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) + upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10)) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel()
Account IDDomainsServices Age Status
{{.AccountID}}{{.Domains}}{{.Services}} {{.Age}} {{.Status}}