Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy

This commit is contained in:
pascal
2026-02-20 00:35:10 +01:00
21 changed files with 859 additions and 9 deletions

View File

@@ -3,6 +3,7 @@ package accesslogs
import (
"net/http"
"strconv"
"strings"
"time"
)
@@ -11,15 +12,39 @@ const (
DefaultPageSize = 50
// MaxPageSize is the maximum number of records allowed per page
MaxPageSize = 100
// Default sorting
DefaultSortBy = "timestamp"
DefaultSortOrder = "desc"
)
// AccessLogFilter holds pagination and filtering parameters for access logs
// validSortFields maps the externally accepted sort keys to the database
// column expression used for ordering. Multi-column sorts use a
// comma-separated list (e.g. "host, path"). Only keys present in this map
// are accepted; anything else falls back to the DefaultSortBy column.
var validSortFields = map[string]string{
	"timestamp":   "timestamp",
	"host":        "host",
	"path":        "path",
	"url":         "host, path", // sort by host first, then path
	"method":      "method",
	"status_code": "status_code",
	"duration":    "duration",
	"source_ip":   "location_connection_ip",
	"user_id":     "user_id",
	"auth_method": "auth_method_used",
	"reason":      "reason",
}
// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs
type AccessLogFilter struct {
// Page is the current page number (1-indexed)
Page int
// PageSize is the number of records per page
PageSize int
// Sorting parameters
SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason
SortOrder string // Sort order: asc or desc (default: desc)
// Filtering parameters
Search *string // General search across log ID, host, path, source IP, and user fields
SourceIP *string // Filter by source IP address
@@ -35,13 +60,16 @@ type AccessLogFilter struct {
EndDate *time.Time // Filter by timestamp <= end_date
}
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
queryParams := r.URL.Query()
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
f.SortBy = parseSortField(queryParams.Get("sort_by"))
f.SortOrder = parseSortOrder(queryParams.Get("sort_order"))
f.Search = parseOptionalString(queryParams.Get("search"))
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
f.Host = parseOptionalString(queryParams.Get("host"))
@@ -107,3 +135,44 @@ func (f *AccessLogFilter) GetOffset() int {
// GetLimit returns the LIMIT value for paginated database queries, i.e. the
// page size that ParseFromRequest validated and capped at MaxPageSize.
func (f *AccessLogFilter) GetLimit() int {
	return f.PageSize
}
// GetSortColumn returns a safe database column expression for the current
// SortBy value. Unknown fields resolve to the DefaultSortBy column, so the
// result is always one of the whitelisted expressions in validSortFields
// and never raw user input.
func (f *AccessLogFilter) GetSortColumn() string {
	column, ok := validSortFields[f.SortBy]
	if !ok {
		column = validSortFields[DefaultSortBy]
	}
	return column
}
// GetSortOrder returns the validated sort direction, always lowercase "asc"
// or "desc". The value is normalized with strings.ToLower so a SortOrder set
// directly on the struct (e.g. "DESC") validates the same way as one parsed
// from a request by parseSortOrder; any other value falls back to
// DefaultSortOrder.
func (f *AccessLogFilter) GetSortOrder() string {
	order := strings.ToLower(f.SortOrder)
	if order == "asc" || order == "desc" {
		return order
	}
	return DefaultSortOrder
}
// parseSortField validates and returns the sort field, defaulting if invalid
func parseSortField(s string) string {
if s == "" {
return DefaultSortBy
}
// Check if the field is valid
if _, ok := validSortFields[s]; ok {
return s
}
return DefaultSortBy
}
// parseSortOrder normalizes s to lowercase and returns it when it is a valid
// sort direction; everything else (including the empty string from a missing
// query parameter) falls back to DefaultSortOrder.
func parseSortOrder(s string) string {
	switch strings.ToLower(s) {
	case "asc":
		return "asc"
	case "desc":
		return "desc"
	default:
		return DefaultSortOrder
	}
}

View File

@@ -361,6 +361,205 @@ func TestParseOptionalRFC3339(t *testing.T) {
}
}
// TestAccessLogFilter_SortingDefaults verifies that a request carrying no
// sorting query parameters yields the default sort field and direction, both
// on the raw struct fields and through the Get* accessors.
func TestAccessLogFilter_SortingDefaults(t *testing.T) {
	filter := &AccessLogFilter{}
	filter.ParseFromRequest(httptest.NewRequest(http.MethodGet, "/test", nil))

	assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp")
	assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc")
	assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp")
	assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc")
}
// TestAccessLogFilter_ValidSortFields verifies that every whitelisted sort_by
// query value is kept as-is on the filter and maps to the expected database
// column expression, including the multi-column "url" -> "host, path" case.
func TestAccessLogFilter_ValidSortFields(t *testing.T) {
	tests := []struct {
		name              string
		sortBy            string
		expectedColumn    string // database column expression from validSortFields
		expectedSortByVal string // value retained on filter.SortBy
	}{
		{"timestamp", "timestamp", "timestamp", "timestamp"},
		{"url", "url", "host, path", "url"},
		{"host", "host", "host", "host"},
		{"path", "path", "path", "path"},
		{"method", "method", "method", "method"},
		{"status_code", "status_code", "status_code", "status_code"},
		{"duration", "duration", "duration", "duration"},
		{"source_ip", "source_ip", "location_connection_ip", "source_ip"},
		{"user_id", "user_id", "user_id", "user_id"},
		{"auth_method", "auth_method", "auth_method_used", "auth_method"},
		{"reason", "reason", "reason", "reason"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil)
			filter := &AccessLogFilter{}
			filter.ParseFromRequest(req)
			assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch")
			assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch")
		})
	}
}
// TestAccessLogFilter_InvalidSortField verifies that unknown or
// SQL-injection-style sort_by values fall back to DefaultSortBy and that
// GetSortColumn then resolves to the default whitelisted column.
func TestAccessLogFilter_InvalidSortField(t *testing.T) {
	tests := []struct {
		name     string
		sortBy   string
		expected string
	}{
		{"invalid field", "invalid_field", DefaultSortBy},
		{"empty field", "", DefaultSortBy},
		{"malicious input", "timestamp--DROP", DefaultSortBy},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Set the parameter via url.Values so the empty value is encoded
			// explicitly (sort_by=) instead of being omitted from the URL.
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			q := req.URL.Query()
			q.Set("sort_by", tt.sortBy)
			req.URL.RawQuery = q.Encode()
			filter := &AccessLogFilter{}
			filter.ParseFromRequest(req)
			assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp")
			assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn())
		})
	}
}
// TestAccessLogFilter_SortOrder verifies sort_order parsing: valid directions
// are accepted case-insensitively (normalized to lowercase) and anything
// else, including an empty parameter, falls back to DefaultSortOrder.
func TestAccessLogFilter_SortOrder(t *testing.T) {
	tests := []struct {
		name      string
		sortOrder string
		expected  string
	}{
		{"ascending", "asc", "asc"},
		{"descending", "desc", "desc"},
		{"uppercase ASC", "ASC", "asc"},
		{"uppercase DESC", "DESC", "desc"},
		{"mixed case Asc", "Asc", "asc"},
		{"invalid order", "invalid", DefaultSortOrder},
		{"empty order", "", DefaultSortOrder},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil)
			filter := &AccessLogFilter{}
			filter.ParseFromRequest(req)
			assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch")
		})
	}
}
// TestAccessLogFilter_CompleteSortingScenarios exercises sort_by and
// sort_order together, including mixed valid/invalid combinations, asserting
// the final column expression and direction the query layer would use.
func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) {
	tests := []struct {
		name           string
		sortBy         string
		sortOrder      string
		expectedColumn string
		expectedOrder  string
	}{
		{
			name:           "sort by host ascending",
			sortBy:         "host",
			sortOrder:      "asc",
			expectedColumn: "host",
			expectedOrder:  "asc",
		},
		{
			name:           "sort by duration descending",
			sortBy:         "duration",
			sortOrder:      "desc",
			expectedColumn: "duration",
			expectedOrder:  "desc",
		},
		{
			name:           "sort by status_code ascending",
			sortBy:         "status_code",
			sortOrder:      "asc",
			expectedColumn: "status_code",
			expectedOrder:  "asc",
		},
		{
			// Invalid field falls back to the timestamp column while the
			// valid order is kept.
			name:           "invalid sort with valid order",
			sortBy:         "invalid",
			sortOrder:      "asc",
			expectedColumn: "timestamp",
			expectedOrder:  "asc",
		},
		{
			// Valid field is kept while the invalid order falls back to the
			// default direction.
			name:           "valid sort with invalid order",
			sortBy:         "method",
			sortOrder:      "invalid",
			expectedColumn: "method",
			expectedOrder:  DefaultSortOrder,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil)
			filter := &AccessLogFilter{}
			filter.ParseFromRequest(req)
			assert.Equal(t, tt.expectedColumn, filter.GetSortColumn())
			assert.Equal(t, tt.expectedOrder, filter.GetSortOrder())
		})
	}
}
// TestParseSortField checks that parseSortField accepts only whitelisted
// field names and falls back to DefaultSortBy for everything else, including
// SQL-injection-style input.
func TestParseSortField(t *testing.T) {
	cases := map[string]struct {
		input    string
		expected string
	}{
		"valid field":     {input: "host", expected: "host"},
		"empty string":    {input: "", expected: DefaultSortBy},
		"invalid field":   {input: "invalid", expected: DefaultSortBy},
		"malicious input": {input: "timestamp--DROP", expected: DefaultSortBy},
	}
	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			assert.Equal(t, tc.expected, parseSortField(tc.input))
		})
	}
}
// TestParseSortOrder checks that parseSortOrder normalizes valid directions
// to lowercase and falls back to DefaultSortOrder for invalid or empty input.
func TestParseSortOrder(t *testing.T) {
	cases := map[string]struct {
		input    string
		expected string
	}{
		"asc lowercase":  {input: "asc", expected: "asc"},
		"desc lowercase": {input: "desc", expected: "desc"},
		"ASC uppercase":  {input: "ASC", expected: "asc"},
		"DESC uppercase": {input: "DESC", expected: "desc"},
		"invalid":        {input: "invalid", expected: DefaultSortOrder},
		"empty":          {input: "", expected: DefaultSortOrder},
	}
	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			assert.Equal(t, tc.expected, parseSortOrder(tc.input))
		})
	}
}
// Helper functions for creating pointers
func strPtr(s string) *string {
return &s

View File

@@ -7,4 +7,7 @@ import (
// Manager defines the access log service operations: persisting entries,
// querying them with pagination/filters, and retention-based cleanup.
type Manager interface {
	// SaveAccessLog persists a single access log entry.
	SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
	// GetAllAccessLogs returns the filtered page of access log entries for
	// the given account and user, together with the total record count.
	GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
	// CleanupOldAccessLogs deletes logs older than retentionDays and reports
	// how many were removed. A retention of 0 or less is a no-op.
	CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error)
	// StartPeriodicCleanup launches a background routine that repeatedly
	// removes expired logs until ctx is canceled or StopPeriodicCleanup is
	// called.
	StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int)
	// StopPeriodicCleanup stops the routine started by StartPeriodicCleanup.
	StopPeriodicCleanup()
}

View File

@@ -3,6 +3,7 @@ package manager
import (
"context"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -19,6 +20,7 @@ type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
geo geolocation.Geolocation
cleanupCancel context.CancelFunc
}
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
@@ -78,6 +80,74 @@ func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID st
return logs, totalCount, nil
}
// CleanupOldAccessLogs removes access logs whose timestamp predates the
// retention window (now minus retentionDays). A retention of zero or less
// disables cleanup and returns immediately. It reports how many rows the
// store deleted.
func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
	if retentionDays <= 0 {
		log.WithContext(ctx).Debug("access log cleanup skipped: retention days is 0 or negative")
		return 0, nil
	}

	cutoff := time.Now().AddDate(0, 0, -retentionDays)
	deleted, err := m.store.DeleteOldAccessLogs(ctx, cutoff)
	if err != nil {
		log.WithContext(ctx).Errorf("failed to cleanup old access logs: %v", err)
		return 0, err
	}

	// Only log when something was actually removed to keep routine runs quiet.
	if deleted > 0 {
		log.WithContext(ctx).Infof("cleaned up %d access logs older than %d days", deleted, retentionDays)
	}
	return deleted, nil
}
// StartPeriodicCleanup starts a background goroutine that periodically cleans
// up old access logs. One cleanup runs immediately on start, then every
// cleanupIntervalHours (defaulting to 24 when zero or negative). A
// retentionDays of zero or less disables the routine entirely. Calling this
// again cancels any previously started routine first, so repeated calls do
// not leak goroutines or tickers.
//
// NOTE(review): cleanupCancel is written without synchronization; if
// Start/Stop can be called from multiple goroutines concurrently, a mutex is
// needed — confirm the lifecycle is single-threaded.
func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
	if retentionDays <= 0 {
		log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative")
		return
	}
	if cleanupIntervalHours <= 0 {
		cleanupIntervalHours = 24
	}

	// Cancel a previously started routine so its goroutine and ticker are
	// released before we overwrite cleanupCancel.
	m.StopPeriodicCleanup()

	cleanupCtx, cancel := context.WithCancel(ctx)
	m.cleanupCancel = cancel

	cleanupInterval := time.Duration(cleanupIntervalHours) * time.Hour
	ticker := time.NewTicker(cleanupInterval)

	go func() {
		defer ticker.Stop()

		// Run cleanup immediately on startup so a long interval doesn't delay
		// the first pass.
		log.WithContext(cleanupCtx).Infof("starting access log cleanup routine (retention: %d days, interval: %d hours)", retentionDays, cleanupIntervalHours)
		if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
			log.WithContext(cleanupCtx).Errorf("initial access log cleanup failed: %v", err)
		}

		for {
			select {
			case <-cleanupCtx.Done():
				log.WithContext(cleanupCtx).Info("stopping access log cleanup routine")
				return
			case <-ticker.C:
				if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
					log.WithContext(cleanupCtx).Errorf("periodic access log cleanup failed: %v", err)
				}
			}
		}
	}()
}
// StopPeriodicCleanup cancels the background cleanup routine started by
// StartPeriodicCleanup. It is safe to call when no routine was ever started.
// The stored cancel func is cleared afterwards so repeated calls are no-ops
// and the canceled context's resources can be collected.
func (m *managerImpl) StopPeriodicCleanup() {
	if m.cleanupCancel != nil {
		m.cleanupCancel()
		m.cleanupCancel = nil
	}
}
// resolveUserFilters converts user email/name filters to user ID filter
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
if filter.UserEmail == nil && filter.UserName == nil {

View File

@@ -0,0 +1,281 @@
package manager
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/store"
)
// TestCleanupOldAccessLogs covers the manager's retention cleanup: the store
// is called with a cutoff ~retentionDays in the past, the deleted count is
// propagated, and zero/negative retention skips the store call entirely
// (verified implicitly by gomock's missing-expectation failure).
func TestCleanupOldAccessLogs(t *testing.T) {
	tests := []struct {
		name          string
		retentionDays int
		setupMock     func(*store.MockStore)
		expectedCount int64
		expectedError bool
	}{
		{
			name:          "cleanup logs older than retention period",
			retentionDays: 30,
			setupMock: func(mockStore *store.MockStore) {
				mockStore.EXPECT().
					DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
					DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
						// Allow up to 1s of drift between the manager's
						// time.Now() and this check.
						expectedCutoff := time.Now().AddDate(0, 0, -30)
						timeDiff := olderThan.Sub(expectedCutoff)
						if timeDiff.Abs() > time.Second {
							t.Errorf("cutoff time not as expected: got %v, want ~%v", olderThan, expectedCutoff)
						}
						return 5, nil
					})
			},
			expectedCount: 5,
			expectedError: false,
		},
		{
			name:          "no logs to cleanup",
			retentionDays: 30,
			setupMock: func(mockStore *store.MockStore) {
				mockStore.EXPECT().
					DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
					Return(int64(0), nil)
			},
			expectedCount: 0,
			expectedError: false,
		},
		{
			name:          "zero retention days skips cleanup",
			retentionDays: 0,
			setupMock: func(mockStore *store.MockStore) {
				// No expectations - DeleteOldAccessLogs should not be called
			},
			expectedCount: 0,
			expectedError: false,
		},
		{
			name:          "negative retention days skips cleanup",
			retentionDays: -10,
			setupMock: func(mockStore *store.MockStore) {
				// No expectations - DeleteOldAccessLogs should not be called
			},
			expectedCount: 0,
			expectedError: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()
			mockStore := store.NewMockStore(ctrl)
			tt.setupMock(mockStore)
			// Only the store dependency is needed by CleanupOldAccessLogs.
			manager := &managerImpl{
				store: mockStore,
			}
			ctx := context.Background()
			deletedCount, err := manager.CleanupOldAccessLogs(ctx, tt.retentionDays)
			if tt.expectedError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
			assert.Equal(t, tt.expectedCount, deletedCount, "unexpected number of deleted logs")
		})
	}
}
// TestCleanupWithExactBoundary asserts the cutoff passed to the store is
// within one second of now minus the 30-day retention, i.e. the boundary
// computation uses AddDate rather than some coarser approximation.
func TestCleanupWithExactBoundary(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	mockStore := store.NewMockStore(ctrl)
	mockStore.EXPECT().
		DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
		DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
			expectedCutoff := time.Now().AddDate(0, 0, -30)
			timeDiff := olderThan.Sub(expectedCutoff)
			assert.Less(t, timeDiff.Abs(), time.Second, "cutoff time should be close to expected value")
			return 1, nil
		})
	manager := &managerImpl{
		store: mockStore,
	}
	ctx := context.Background()
	deletedCount, err := manager.CleanupOldAccessLogs(ctx, 30)
	require.NoError(t, err)
	assert.Equal(t, int64(1), deletedCount)
}
// TestStartPeriodicCleanup covers the background cleanup routine's lifecycle:
// disabled for zero retention, one immediate run on start, shutdown via
// context cancel, and interval handling (both the 24h default for invalid
// input and a configured value). Synchronization is sleep-based: the sleeps
// give the goroutine time to perform (or provably skip) its immediate run
// before gomock verifies call counts on ctrl.Finish().
// NOTE(review): sleep-based synchronization can be flaky on slow CI —
// consider signaling from the mock instead.
func TestStartPeriodicCleanup(t *testing.T) {
	t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		mockStore := store.NewMockStore(ctrl)
		// No expectations - cleanup should not run
		manager := &managerImpl{
			store: mockStore,
		}
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		manager.StartPeriodicCleanup(ctx, 0, 1)
		time.Sleep(100 * time.Millisecond)
		// If DeleteOldAccessLogs was called, the test will fail due to unexpected call
	})
	t.Run("periodic cleanup runs immediately on start", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		mockStore := store.NewMockStore(ctrl)
		// Exactly one call: the immediate startup run (the 24h ticker cannot
		// fire within the 200ms window).
		mockStore.EXPECT().
			DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
			Return(int64(2), nil).
			Times(1)
		manager := &managerImpl{
			store: mockStore,
		}
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		manager.StartPeriodicCleanup(ctx, 30, 24)
		time.Sleep(200 * time.Millisecond)
		// Expectations verified by gomock on defer ctrl.Finish()
	})
	t.Run("periodic cleanup stops on context cancel", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
			Return(int64(1), nil).
			Times(1)
		manager := &managerImpl{
			store: mockStore,
		}
		ctx, cancel := context.WithCancel(context.Background())
		manager.StartPeriodicCleanup(ctx, 30, 24)
		time.Sleep(100 * time.Millisecond)
		// Canceling the parent context must stop the routine; the trailing
		// sleep gives a stuck routine a chance to make an unexpected call.
		cancel()
		time.Sleep(200 * time.Millisecond)
	})
	t.Run("cleanup interval defaults to 24 hours when invalid", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
			Return(int64(0), nil).
			Times(1)
		manager := &managerImpl{
			store: mockStore,
		}
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		// Interval 0 is invalid; the routine must still start (using the
		// 24h default) and perform its immediate run.
		manager.StartPeriodicCleanup(ctx, 30, 0)
		time.Sleep(100 * time.Millisecond)
		manager.StopPeriodicCleanup()
	})
	t.Run("cleanup interval uses configured hours", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		mockStore := store.NewMockStore(ctrl)
		mockStore.EXPECT().
			DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
			Return(int64(3), nil).
			Times(1)
		manager := &managerImpl{
			store: mockStore,
		}
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		manager.StartPeriodicCleanup(ctx, 30, 12)
		time.Sleep(100 * time.Millisecond)
		manager.StopPeriodicCleanup()
	})
}
// TestStopPeriodicCleanup verifies that StopPeriodicCleanup halts the
// routine: after the stop, no further store calls may occur (gomock enforces
// exactly one call - the immediate startup run).
func TestStopPeriodicCleanup(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	mockStore := store.NewMockStore(ctrl)
	mockStore.EXPECT().
		DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
		Return(int64(1), nil).
		Times(1)
	manager := &managerImpl{
		store: mockStore,
	}
	ctx := context.Background()
	manager.StartPeriodicCleanup(ctx, 30, 24)
	// Let the immediate cleanup run, stop, then wait to catch stray calls.
	time.Sleep(100 * time.Millisecond)
	manager.StopPeriodicCleanup()
	time.Sleep(200 * time.Millisecond)
	// Expectations verified by gomock - would fail if more than 1 call happened
}
// TestStopPeriodicCleanup_NotStarted ensures StopPeriodicCleanup is a safe
// no-op when StartPeriodicCleanup was never called.
func TestStopPeriodicCleanup_NotStarted(t *testing.T) {
	var manager managerImpl
	// cleanupCancel is nil here; the call must not panic.
	manager.StopPeriodicCleanup()
}