diff --git a/client/internal/engine.go b/client/internal/engine.go index 4f3cf0998..f2d724aa4 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -28,8 +28,8 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" - nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/device" + nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/debug" @@ -1562,8 +1562,10 @@ func (e *Engine) receiveSignalEvents() { defer e.shutdownWg.Done() // connect to a stream of messages coming from the signal server err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { + start := time.Now() e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + gotLock := time.Since(start) // Check context INSIDE lock to ensure atomicity with shutdown if e.ctx.Err() != nil { @@ -1587,6 +1589,8 @@ func (e *Engine) receiveSignalEvents() { return err } + log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID) + if msg.Body.Type == sProto.Body_OFFER { conn.OnRemoteOffer(*offerAnswer) } else { diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 12c9ff4af..4bf0d5476 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -351,6 +351,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log. logger.Errorf("failed to update domain prefixes: %v", err) } + // Allow time for route changes to be applied before sending + // the DNS response (relevant on iOS where setTunnelNetworkSettings + // is asynchronous). + waitForRouteSettlement(logger) + d.replaceIPsInDNSResponse(r, newPrefixes, logger) } } diff --git a/client/internal/routemanager/dnsinterceptor/handler_ios.go b/client/internal/routemanager/dnsinterceptor/handler_ios.go new file mode 100644 index 000000000..4cf80eb16 --- /dev/null +++ b/client/internal/routemanager/dnsinterceptor/handler_ios.go @@ -0,0 +1,20 @@ +//go:build ios + +package dnsinterceptor + +import ( + "time" + + log "github.com/sirupsen/logrus" +) + +const routeSettleDelay = 500 * time.Millisecond + +// waitForRouteSettlement introduces a short delay on iOS to allow +// setTunnelNetworkSettings to apply route changes before the DNS +// response reaches the application. Without this, the first request +// to a newly resolved domain may bypass the tunnel. +func waitForRouteSettlement(logger *log.Entry) { + logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay) + time.Sleep(routeSettleDelay) +} diff --git a/client/internal/routemanager/dnsinterceptor/handler_nonios.go b/client/internal/routemanager/dnsinterceptor/handler_nonios.go new file mode 100644 index 000000000..68cd7330b --- /dev/null +++ b/client/internal/routemanager/dnsinterceptor/handler_nonios.go @@ -0,0 +1,12 @@ +//go:build !ios + +package dnsinterceptor + +import log "github.com/sirupsen/logrus" + +func waitForRouteSettlement(_ *log.Entry) { + // No-op on non-iOS platforms: route changes are applied synchronously by + // the kernel, so no settlement delay is needed before the DNS response + // reaches the application. The delay is only required on iOS where + // setTunnelNetworkSettings applies routes asynchronously. +} diff --git a/formatter/txt/formatter.go b/formatter/txt/formatter.go index 3b2a3fb4d..4f174a740 100644 --- a/formatter/txt/formatter.go +++ b/formatter/txt/formatter.go @@ -1,8 +1,6 @@ package txt import ( - "time" - "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/formatter/levels" @@ -18,7 +16,7 @@ type TextFormatter struct { func NewTextFormatter() *TextFormatter { return &TextFormatter{ levelDesc: levels.ValidLevelDesc, - timestampFormat: time.RFC3339, // or RFC3339 + timestampFormat: "2006-01-02T15:04:05.000Z07:00", } } diff --git a/formatter/txt/formatter_test.go b/formatter/txt/formatter_test.go index 590af5d50..1b20a3ebf 100644 --- a/formatter/txt/formatter_test.go +++ b/formatter/txt/formatter_test.go @@ -21,6 +21,6 @@ func TestLogTextFormat(t *testing.T) { result, _ := formatter.Format(someEntry) parsedString := string(result) - expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$" + expectedString := "^2021-02-21T01:10:30.000Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$" assert.Regexp(t, expectedString, parsedString) } diff --git a/management/internals/modules/reverseproxy/accesslogs/filter.go b/management/internals/modules/reverseproxy/accesslogs/filter.go index f4b0a2048..a1fa28312 100644 --- a/management/internals/modules/reverseproxy/accesslogs/filter.go +++ b/management/internals/modules/reverseproxy/accesslogs/filter.go @@ -3,6 +3,7 @@ package accesslogs import ( "net/http" "strconv" + "strings" "time" ) @@ -11,15 +12,39 @@ const ( DefaultPageSize = 50 // MaxPageSize is the maximum number of records allowed per page MaxPageSize = 100 + + // Default sorting + DefaultSortBy = "timestamp" + DefaultSortOrder = "desc" ) -// AccessLogFilter holds pagination and filtering parameters for access logs +// Valid sortable fields mapped to their database column names or expressions +// For multi-column sorts, columns are separated by comma (e.g., "host, path") +var validSortFields = map[string]string{ + "timestamp": "timestamp", + "url": "host, path", // Sort by host first, then path + "host": "host", + "path": "path", + "method": "method", + "status_code": "status_code", + "duration": "duration", + "source_ip": "location_connection_ip", + "user_id": "user_id", + "auth_method": "auth_method_used", + "reason": "reason", +} + +// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs type AccessLogFilter struct { // Page is the current page number (1-indexed) Page int // PageSize is the number of records per page PageSize int + // Sorting parameters + SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason + SortOrder string // Sort order: asc or desc (default: desc) + // Filtering parameters Search *string // General search across log ID, host, path, source IP, and user fields SourceIP *string // Filter by source IP address @@ -35,13 +60,16 @@ type AccessLogFilter struct { EndDate *time.Time // Filter by timestamp <= end_date } -// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters +// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters func (f *AccessLogFilter) ParseFromRequest(r *http.Request) { queryParams := r.URL.Query() f.Page = parsePositiveInt(queryParams.Get("page"), 1) f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize) + f.SortBy = parseSortField(queryParams.Get("sort_by")) + f.SortOrder = parseSortOrder(queryParams.Get("sort_order")) + f.Search = parseOptionalString(queryParams.Get("search")) f.SourceIP = parseOptionalString(queryParams.Get("source_ip")) f.Host = parseOptionalString(queryParams.Get("host")) @@ -107,3 +135,44 @@ func (f *AccessLogFilter) GetOffset() int { func (f *AccessLogFilter) GetLimit() int { return f.PageSize } + +// GetSortColumn returns the validated database column name for sorting +func (f *AccessLogFilter) GetSortColumn() string { + if column, ok := validSortFields[f.SortBy]; ok { + return column + } + return validSortFields[DefaultSortBy] +} + +// GetSortOrder returns the validated sort order (ASC or DESC) +func (f *AccessLogFilter) GetSortOrder() string { + if f.SortOrder == "asc" || f.SortOrder == "desc" { + return f.SortOrder + } + return DefaultSortOrder +} + +// parseSortField validates and returns the sort field, defaulting if invalid +func parseSortField(s string) string { + if s == "" { + return DefaultSortBy + } + // Check if the field is valid + if _, ok := validSortFields[s]; ok { + return s + } + return DefaultSortBy +} + +// parseSortOrder validates and returns the sort order, defaulting if invalid +func parseSortOrder(s string) string { + if s == "" { + return DefaultSortOrder + } + // Normalize to lowercase + s = strings.ToLower(s) + if s == "asc" || s == "desc" { + return s + } + return DefaultSortOrder +} diff --git a/management/internals/modules/reverseproxy/accesslogs/filter_test.go b/management/internals/modules/reverseproxy/accesslogs/filter_test.go index 5d48ea9d2..ea1fce54b 100644 --- a/management/internals/modules/reverseproxy/accesslogs/filter_test.go +++ b/management/internals/modules/reverseproxy/accesslogs/filter_test.go @@ -361,6 +361,205 @@ func TestParseOptionalRFC3339(t *testing.T) { } } +func TestAccessLogFilter_SortingDefaults(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp") + assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc") + assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp") + assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc") +} + +func TestAccessLogFilter_ValidSortFields(t *testing.T) { + tests := []struct { + name string + sortBy string + expectedColumn string + expectedSortByVal string + }{ + {"timestamp", "timestamp", "timestamp", "timestamp"}, + {"url", "url", "host, path", "url"}, + {"host", "host", "host", "host"}, + {"path", "path", "path", "path"}, + {"method", "method", "method", "method"}, + {"status_code", "status_code", "status_code", "status_code"}, + {"duration", "duration", "duration", "duration"}, + {"source_ip", "source_ip", "location_connection_ip", "source_ip"}, + {"user_id", "user_id", "user_id", "user_id"}, + {"auth_method", "auth_method", "auth_method_used", "auth_method"}, + {"reason", "reason", "reason", "reason"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch") + assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch") + }) + } +} + +func TestAccessLogFilter_InvalidSortField(t *testing.T) { + tests := []struct { + name string + sortBy string + expected string + }{ + {"invalid field", "invalid_field", DefaultSortBy}, + {"empty field", "", DefaultSortBy}, + {"malicious input", "timestamp--DROP", DefaultSortBy}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + q := req.URL.Query() + q.Set("sort_by", tt.sortBy) + req.URL.RawQuery = q.Encode() + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp") + assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn()) + }) + } +} + +func TestAccessLogFilter_SortOrder(t *testing.T) { + tests := []struct { + name string + sortOrder string + expected string + }{ + {"ascending", "asc", "asc"}, + {"descending", "desc", "desc"}, + {"uppercase ASC", "ASC", "asc"}, + {"uppercase DESC", "DESC", "desc"}, + {"mixed case Asc", "Asc", "asc"}, + {"invalid order", "invalid", DefaultSortOrder}, + {"empty order", "", DefaultSortOrder}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch") + }) + } +} + +func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) { + tests := []struct { + name string + sortBy string + sortOrder string + expectedColumn string + expectedOrder string + }{ + { + name: "sort by host ascending", + sortBy: "host", + sortOrder: "asc", + expectedColumn: "host", + expectedOrder: "asc", + }, + { + name: "sort by duration descending", + sortBy: "duration", + sortOrder: "desc", + expectedColumn: "duration", + expectedOrder: "desc", + }, + { + name: "sort by status_code ascending", + sortBy: "status_code", + sortOrder: "asc", + expectedColumn: "status_code", + expectedOrder: "asc", + }, + { + name: "invalid sort with valid order", + sortBy: "invalid", + sortOrder: "asc", + expectedColumn: "timestamp", + expectedOrder: "asc", + }, + { + name: "valid sort with invalid order", + sortBy: "method", + sortOrder: "invalid", + expectedColumn: "method", + expectedOrder: DefaultSortOrder, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expectedColumn, filter.GetSortColumn()) + assert.Equal(t, tt.expectedOrder, filter.GetSortOrder()) + }) + } +} + +func TestParseSortField(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"valid field", "host", "host"}, + {"empty string", "", DefaultSortBy}, + {"invalid field", "invalid", DefaultSortBy}, + {"malicious input", "timestamp--DROP", DefaultSortBy}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseSortField(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseSortOrder(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"asc lowercase", "asc", "asc"}, + {"desc lowercase", "desc", "desc"}, + {"ASC uppercase", "ASC", "asc"}, + {"DESC uppercase", "DESC", "desc"}, + {"invalid", "invalid", DefaultSortOrder}, + {"empty", "", DefaultSortOrder}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseSortOrder(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + // Helper functions for creating pointers func strPtr(s string) *string { return &s diff --git a/management/internals/modules/reverseproxy/accesslogs/interface.go b/management/internals/modules/reverseproxy/accesslogs/interface.go index 1c51a8a7d..04f096bf1 100644 --- a/management/internals/modules/reverseproxy/accesslogs/interface.go +++ b/management/internals/modules/reverseproxy/accesslogs/interface.go @@ -7,4 +7,7 @@ import ( type Manager interface { SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error) + CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) + StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) + StopPeriodicCleanup() } diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go index 7bcdecb1b..e7fba7bed 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go @@ -3,6 +3,7 @@ package manager import ( "context" "strings" + "time" log "github.com/sirupsen/logrus" @@ -19,6 +20,7 @@ type managerImpl struct { store store.Store permissionsManager permissions.Manager geo geolocation.Geolocation + cleanupCancel context.CancelFunc } func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager { @@ -78,6 +80,74 @@ func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID st return logs, totalCount, nil } +// CleanupOldAccessLogs deletes access logs older than the specified retention period +func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) { + if retentionDays <= 0 { + log.WithContext(ctx).Debug("access log cleanup skipped: retention days is 0 or negative") + return 0, nil + } + + cutoffTime := time.Now().AddDate(0, 0, -retentionDays) + deletedCount, err := m.store.DeleteOldAccessLogs(ctx, cutoffTime) + if err != nil { + log.WithContext(ctx).Errorf("failed to cleanup old access logs: %v", err) + return 0, err + } + + if deletedCount > 0 { + log.WithContext(ctx).Infof("cleaned up %d access logs older than %d days", deletedCount, retentionDays) + } + + return deletedCount, nil +} + +// StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs +func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) { + if retentionDays <= 0 { + log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative") + return + } + + if cleanupIntervalHours <= 0 { + cleanupIntervalHours = 24 + } + + cleanupCtx, cancel := context.WithCancel(ctx) + m.cleanupCancel = cancel + + cleanupInterval := time.Duration(cleanupIntervalHours) * time.Hour + ticker := time.NewTicker(cleanupInterval) + + go func() { + defer ticker.Stop() + + // Run cleanup immediately on startup + log.WithContext(cleanupCtx).Infof("starting access log cleanup routine (retention: %d days, interval: %d hours)", retentionDays, cleanupIntervalHours) + if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil { + log.WithContext(cleanupCtx).Errorf("initial access log cleanup failed: %v", err) + } + + for { + select { + case <-cleanupCtx.Done(): + log.WithContext(cleanupCtx).Info("stopping access log cleanup routine") + return + case <-ticker.C: + if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil { + log.WithContext(cleanupCtx).Errorf("periodic access log cleanup failed: %v", err) + } + } + } + }() +} + +// StopPeriodicCleanup stops the periodic cleanup routine +func (m *managerImpl) StopPeriodicCleanup() { + if m.cleanupCancel != nil { + m.cleanupCancel() + } +} + // resolveUserFilters converts user email/name filters to user ID filter func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error { if filter.UserEmail == nil && filter.UserName == nil { diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go new file mode 100644 index 000000000..8fadef85f --- /dev/null +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go @@ -0,0 +1,281 @@ +package manager + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/store" +) + +func TestCleanupOldAccessLogs(t *testing.T) { + tests := []struct { + name string + retentionDays int + setupMock func(*store.MockStore) + expectedCount int64 + expectedError bool + }{ + { + name: "cleanup logs older than retention period", + retentionDays: 30, + setupMock: func(mockStore *store.MockStore) { + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) { + expectedCutoff := time.Now().AddDate(0, 0, -30) + timeDiff := olderThan.Sub(expectedCutoff) + if timeDiff.Abs() > time.Second { + t.Errorf("cutoff time not as expected: got %v, want ~%v", olderThan, expectedCutoff) + } + return 5, nil + }) + }, + expectedCount: 5, + expectedError: false, + }, + { + name: "no logs to cleanup", + retentionDays: 30, + setupMock: func(mockStore *store.MockStore) { + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + Return(int64(0), nil) + }, + expectedCount: 0, + expectedError: false, + }, + { + name: "zero retention days skips cleanup", + retentionDays: 0, + setupMock: func(mockStore *store.MockStore) { + // No expectations - DeleteOldAccessLogs should not be called + }, + expectedCount: 0, + expectedError: false, + }, + { + name: "negative retention days skips cleanup", + retentionDays: -10, + setupMock: func(mockStore *store.MockStore) { + // No expectations - DeleteOldAccessLogs should not be called + }, + expectedCount: 0, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + tt.setupMock(mockStore) + + manager := &managerImpl{ + store: mockStore, + } + + ctx := context.Background() + deletedCount, err := manager.CleanupOldAccessLogs(ctx, tt.retentionDays) + + if tt.expectedError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.expectedCount, deletedCount, "unexpected number of deleted logs") + }) + } +} + +func TestCleanupWithExactBoundary(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) { + expectedCutoff := time.Now().AddDate(0, 0, -30) + timeDiff := olderThan.Sub(expectedCutoff) + assert.Less(t, timeDiff.Abs(), time.Second, "cutoff time should be close to expected value") + return 1, nil + }) + + manager := &managerImpl{ + store: mockStore, + } + + ctx := context.Background() + deletedCount, err := manager.CleanupOldAccessLogs(ctx, 30) + + require.NoError(t, err) + assert.Equal(t, int64(1), deletedCount) +} + +func TestStartPeriodicCleanup(t *testing.T) { + t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + // No expectations - cleanup should not run + + manager := &managerImpl{ + store: mockStore, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + manager.StartPeriodicCleanup(ctx, 0, 1) + + time.Sleep(100 * time.Millisecond) + + // If DeleteOldAccessLogs was called, the test will fail due to unexpected call + }) + + t.Run("periodic cleanup runs immediately on start", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + Return(int64(2), nil). + Times(1) + + manager := &managerImpl{ + store: mockStore, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + manager.StartPeriodicCleanup(ctx, 30, 24) + + time.Sleep(200 * time.Millisecond) + + // Expectations verified by gomock on defer ctrl.Finish() + }) + + t.Run("periodic cleanup stops on context cancel", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + Return(int64(1), nil). + Times(1) + + manager := &managerImpl{ + store: mockStore, + } + + ctx, cancel := context.WithCancel(context.Background()) + + manager.StartPeriodicCleanup(ctx, 30, 24) + + time.Sleep(100 * time.Millisecond) + + cancel() + + time.Sleep(200 * time.Millisecond) + + }) + + t.Run("cleanup interval defaults to 24 hours when invalid", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + Return(int64(0), nil). + Times(1) + + manager := &managerImpl{ + store: mockStore, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + manager.StartPeriodicCleanup(ctx, 30, 0) + + time.Sleep(100 * time.Millisecond) + + manager.StopPeriodicCleanup() + }) + + t.Run("cleanup interval uses configured hours", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + Return(int64(3), nil). + Times(1) + + manager := &managerImpl{ + store: mockStore, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + manager.StartPeriodicCleanup(ctx, 30, 12) + + time.Sleep(100 * time.Millisecond) + + manager.StopPeriodicCleanup() + }) +} + +func TestStopPeriodicCleanup(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + mockStore.EXPECT(). + DeleteOldAccessLogs(gomock.Any(), gomock.Any()). + Return(int64(1), nil). + Times(1) + + manager := &managerImpl{ + store: mockStore, + } + + ctx := context.Background() + + manager.StartPeriodicCleanup(ctx, 30, 24) + + time.Sleep(100 * time.Millisecond) + + manager.StopPeriodicCleanup() + + time.Sleep(200 * time.Millisecond) + + // Expectations verified by gomock - would fail if more than 1 call happened +} + +func TestStopPeriodicCleanup_NotStarted(t *testing.T) { + manager := &managerImpl{} + + // Should not panic if cleanup was never started + manager.StopPeriodicCleanup() +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 5e5b1622a..09fc2bd94 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -200,6 +200,11 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { func (s *BaseServer) AccessLogsManager() accesslogs.Manager { return Create(s, func() accesslogs.Manager { accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) + accessLogManager.StartPeriodicCleanup( + context.Background(), + s.Config.ReverseProxy.AccessLogRetentionDays, + s.Config.ReverseProxy.AccessLogCleanupIntervalHours, + ) return accessLogManager }) } diff --git a/management/internals/server/config/config.go b/management/internals/server/config/config.go index 5ed1c3ede..0ba393263 100644 --- a/management/internals/server/config/config.go +++ b/management/internals/server/config/config.go @@ -200,4 +200,13 @@ type ReverseProxy struct { // request headers if the peer's address falls within one of these // trusted IP prefixes. TrustedPeers []netip.Prefix + + // AccessLogRetentionDays specifies the number of days to retain access logs. + // Logs older than this duration will be automatically deleted during cleanup. + // A value of 0 or negative means logs are kept indefinitely (no cleanup). + AccessLogRetentionDays int + + // AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine. + // Defaults to 24 hours if not set or set to 0. + AccessLogCleanupIntervalHours int } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index e1e6b5680..f92e575d0 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -157,6 +157,18 @@ type testSetup struct { // testAccessLogManager is a minimal mock for accesslogs.Manager. type testAccessLogManager struct{} +func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) { + return 0, nil +} + +func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) { + return +} + +func (m *testAccessLogManager) StopPeriodicCleanup() { + return +} + func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error { return nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index c1a67b186..7c70dafcb 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5083,8 +5083,20 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin query = s.applyAccessLogFilters(query, filter) + sortColumns := filter.GetSortColumn() + sortOrder := strings.ToUpper(filter.GetSortOrder()) + + var orderClauses []string + for _, col := range strings.Split(sortColumns, ",") { + col = strings.TrimSpace(col) + if col != "" { + orderClauses = append(orderClauses, col+" "+sortOrder) + } + } + orderClause := strings.Join(orderClauses, ", ") + query = query. - Order("timestamp DESC"). + Order(orderClause). Limit(filter.GetLimit()). Offset(filter.GetOffset()) @@ -5101,6 +5113,20 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin return logs, totalCount, nil } +// DeleteOldAccessLogs deletes all access logs older than the specified time +func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) { + result := s.db.WithContext(ctx). + Where("timestamp < ?", olderThan). + Delete(&accesslogs.AccessLogEntry{}) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete old access logs: %v", result.Error) + return 0, status.Errorf(status.Internal, "failed to delete old access logs") + } + + return result.RowsAffected, nil +} + // applyAccessLogFilters applies filter conditions to the query func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.AccessLogFilter) *gorm.DB { if filter.Search != nil { diff --git a/management/server/store/store.go b/management/server/store/store.go index cc9c19be6..1c7179b58 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -270,6 +270,7 @@ type Store interface { CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) + DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 94d14398c..1a7ff5d53 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -475,6 +475,21 @@ func (mr *MockStoreMockRecorder) DeleteNetworkRouter(ctx, accountID, routerID in return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteNetworkRouter", reflect.TypeOf((*MockStore)(nil).DeleteNetworkRouter), ctx, accountID, routerID) } +// DeleteOldAccessLogs mocks base method. +func (m *MockStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteOldAccessLogs", ctx, olderThan) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteOldAccessLogs indicates an expected call of DeleteOldAccessLogs. +func (mr *MockStoreMockRecorder) DeleteOldAccessLogs(ctx, olderThan interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAccessLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAccessLogs), ctx, olderThan) +} + // DeletePAT mocks base method. func (m *MockStore) DeletePAT(ctx context.Context, userID, patID string) error { m.ctrl.T.Helper() diff --git a/management/server/user.go b/management/server/user.go index 48005f325..924efc1e4 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -737,6 +737,14 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } + if initiatorUserId != activity.SystemInitiator { + freshInitiator, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, initiatorUserId) + if err != nil { + return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err) + } + initiatorUser = freshInitiator + } + oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) if err != nil { return false, nil, nil, nil, err @@ -864,7 +872,10 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse return nil } - // @todo double check these + if !initiatorUser.HasAdminPower() { + return status.Errorf(status.PermissionDenied, "only admins and owners can update users") + } + if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } diff --git a/management/server/user_test.go b/management/server/user_test.go index 2dd1cea2e..72a19a9a5 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -2031,3 +2031,87 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) { t.Logf("Duplicate email error: %v", err) }) } + +func TestValidateUserUpdate_RejectsNonAdminInitiator(t *testing.T) { + groupsMap := map[string]*types.Group{} + + initiator := &types.User{ + Id: "initiator", + Role: types.UserRoleUser, + } + oldUser := &types.User{ + Id: "target", + Role: types.UserRoleUser, + } + update := &types.User{ + Id: "target", + Role: types.UserRoleOwner, + } + + err := validateUserUpdate(groupsMap, initiator, oldUser, update) + require.Error(t, err, "regular user should not be able to promote to owner") + assert.Contains(t, err.Error(), "only admins and owners can update users") +} + +func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) { + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + account := newAccountWithId(context.Background(), "account1", "owner1", "", "", "", false) + + adminID := "admin1" + account.Users[adminID] = types.NewAdminUser(adminID) + + targetID := "target1" + account.Users[targetID] = types.NewRegularUser(targetID, "", "") + + require.NoError(t, s.SaveAccount(context.Background(), account)) + + demotedAdmin, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, adminID) + require.NoError(t, err) + demotedAdmin.Role = types.UserRoleUser + require.NoError(t, s.SaveUser(context.Background(), demotedAdmin)) + + staleInitiator := &types.User{ + Id: adminID, + AccountID: account.Id, + Role: types.UserRoleAdmin, + } + + permissionsManager := permissions.NewManager(s) + am := DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + } + + settings, err := s.GetAccountSettings(context.Background(), store.LockingStrengthNone, account.Id) + require.NoError(t, err) + + groups, err := s.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id) + require.NoError(t, err) + groupsMap := make(map[string]*types.Group, len(groups)) + for _, g := range groups { + groupsMap[g.ID] = g + } + + update := &types.User{ + Id: targetID, + Role: types.UserRoleAdmin, + } + + err = s.ExecuteInTransaction(context.Background(), func(tx store.Store) error { + _, _, _, _, txErr := am.processUserUpdate( + context.Background(), tx, groupsMap, account.Id, adminID, staleInitiator, update, false, settings, + ) + return txErr + }) + + require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted") + assert.Contains(t, err.Error(), "only admins and owners can update users") + + targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID) + require.NoError(t, err) + assert.Equal(t, types.UserRoleUser, targetUser.Role) +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 47e735e22..849a41ace 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -166,6 +166,18 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { // testAccessLogManager provides access log storage for testing. type testAccessLogManager struct{} +func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) { + return 0, nil +} + +func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) { + // noop +} + +func (m *testAccessLogManager) StopPeriodicCleanup() { + // noop +} + func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error { return nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 1f4a163e5..b0ce1b5cc 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -7409,6 +7409,20 @@ paths: minimum: 1 maximum: 100 description: Number of items per page (max 100) + - in: query + name: sort_by + schema: + type: string + enum: [timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason] + default: timestamp + description: Field to sort by (url sorts by host then path) + - in: query + name: sort_order + schema: + type: string + enum: [asc, desc] + default: desc + description: Sort order (ascending or descending) - in: query name: search schema: