mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy
This commit is contained in:
@@ -3,6 +3,7 @@ package accesslogs
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -11,15 +12,39 @@ const (
|
||||
DefaultPageSize = 50
|
||||
// MaxPageSize is the maximum number of records allowed per page
|
||||
MaxPageSize = 100
|
||||
|
||||
// Default sorting
|
||||
DefaultSortBy = "timestamp"
|
||||
DefaultSortOrder = "desc"
|
||||
)
|
||||
|
||||
// AccessLogFilter holds pagination and filtering parameters for access logs
|
||||
// Valid sortable fields mapped to their database column names or expressions
|
||||
// For multi-column sorts, columns are separated by comma (e.g., "host, path")
|
||||
var validSortFields = map[string]string{
|
||||
"timestamp": "timestamp",
|
||||
"url": "host, path", // Sort by host first, then path
|
||||
"host": "host",
|
||||
"path": "path",
|
||||
"method": "method",
|
||||
"status_code": "status_code",
|
||||
"duration": "duration",
|
||||
"source_ip": "location_connection_ip",
|
||||
"user_id": "user_id",
|
||||
"auth_method": "auth_method_used",
|
||||
"reason": "reason",
|
||||
}
|
||||
|
||||
// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs
|
||||
type AccessLogFilter struct {
|
||||
// Page is the current page number (1-indexed)
|
||||
Page int
|
||||
// PageSize is the number of records per page
|
||||
PageSize int
|
||||
|
||||
// Sorting parameters
|
||||
SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason
|
||||
SortOrder string // Sort order: asc or desc (default: desc)
|
||||
|
||||
// Filtering parameters
|
||||
Search *string // General search across log ID, host, path, source IP, and user fields
|
||||
SourceIP *string // Filter by source IP address
|
||||
@@ -35,13 +60,16 @@ type AccessLogFilter struct {
|
||||
EndDate *time.Time // Filter by timestamp <= end_date
|
||||
}
|
||||
|
||||
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
|
||||
// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters
|
||||
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||
queryParams := r.URL.Query()
|
||||
|
||||
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||
|
||||
f.SortBy = parseSortField(queryParams.Get("sort_by"))
|
||||
f.SortOrder = parseSortOrder(queryParams.Get("sort_order"))
|
||||
|
||||
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||
@@ -107,3 +135,44 @@ func (f *AccessLogFilter) GetOffset() int {
|
||||
func (f *AccessLogFilter) GetLimit() int {
|
||||
return f.PageSize
|
||||
}
|
||||
|
||||
// GetSortColumn returns the validated database column name for sorting
|
||||
func (f *AccessLogFilter) GetSortColumn() string {
|
||||
if column, ok := validSortFields[f.SortBy]; ok {
|
||||
return column
|
||||
}
|
||||
return validSortFields[DefaultSortBy]
|
||||
}
|
||||
|
||||
// GetSortOrder returns the validated sort order (ASC or DESC)
|
||||
func (f *AccessLogFilter) GetSortOrder() string {
|
||||
if f.SortOrder == "asc" || f.SortOrder == "desc" {
|
||||
return f.SortOrder
|
||||
}
|
||||
return DefaultSortOrder
|
||||
}
|
||||
|
||||
// parseSortField validates and returns the sort field, defaulting if invalid
|
||||
func parseSortField(s string) string {
|
||||
if s == "" {
|
||||
return DefaultSortBy
|
||||
}
|
||||
// Check if the field is valid
|
||||
if _, ok := validSortFields[s]; ok {
|
||||
return s
|
||||
}
|
||||
return DefaultSortBy
|
||||
}
|
||||
|
||||
// parseSortOrder validates and returns the sort order, defaulting if invalid
|
||||
func parseSortOrder(s string) string {
|
||||
if s == "" {
|
||||
return DefaultSortOrder
|
||||
}
|
||||
// Normalize to lowercase
|
||||
s = strings.ToLower(s)
|
||||
if s == "asc" || s == "desc" {
|
||||
return s
|
||||
}
|
||||
return DefaultSortOrder
|
||||
}
|
||||
|
||||
@@ -361,6 +361,205 @@ func TestParseOptionalRFC3339(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_SortingDefaults(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp")
|
||||
assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc")
|
||||
assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp")
|
||||
assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc")
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ValidSortFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
expectedColumn string
|
||||
expectedSortByVal string
|
||||
}{
|
||||
{"timestamp", "timestamp", "timestamp", "timestamp"},
|
||||
{"url", "url", "host, path", "url"},
|
||||
{"host", "host", "host", "host"},
|
||||
{"path", "path", "path", "path"},
|
||||
{"method", "method", "method", "method"},
|
||||
{"status_code", "status_code", "status_code", "status_code"},
|
||||
{"duration", "duration", "duration", "duration"},
|
||||
{"source_ip", "source_ip", "location_connection_ip", "source_ip"},
|
||||
{"user_id", "user_id", "user_id", "user_id"},
|
||||
{"auth_method", "auth_method", "auth_method_used", "auth_method"},
|
||||
{"reason", "reason", "reason", "reason"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch")
|
||||
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_InvalidSortField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
expected string
|
||||
}{
|
||||
{"invalid field", "invalid_field", DefaultSortBy},
|
||||
{"empty field", "", DefaultSortBy},
|
||||
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
q := req.URL.Query()
|
||||
q.Set("sort_by", tt.sortBy)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp")
|
||||
assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_SortOrder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortOrder string
|
||||
expected string
|
||||
}{
|
||||
{"ascending", "asc", "asc"},
|
||||
{"descending", "desc", "desc"},
|
||||
{"uppercase ASC", "ASC", "asc"},
|
||||
{"uppercase DESC", "DESC", "desc"},
|
||||
{"mixed case Asc", "Asc", "asc"},
|
||||
{"invalid order", "invalid", DefaultSortOrder},
|
||||
{"empty order", "", DefaultSortOrder},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
sortOrder string
|
||||
expectedColumn string
|
||||
expectedOrder string
|
||||
}{
|
||||
{
|
||||
name: "sort by host ascending",
|
||||
sortBy: "host",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "host",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "sort by duration descending",
|
||||
sortBy: "duration",
|
||||
sortOrder: "desc",
|
||||
expectedColumn: "duration",
|
||||
expectedOrder: "desc",
|
||||
},
|
||||
{
|
||||
name: "sort by status_code ascending",
|
||||
sortBy: "status_code",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "status_code",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "invalid sort with valid order",
|
||||
sortBy: "invalid",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "timestamp",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "valid sort with invalid order",
|
||||
sortBy: "method",
|
||||
sortOrder: "invalid",
|
||||
expectedColumn: "method",
|
||||
expectedOrder: DefaultSortOrder,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn())
|
||||
assert.Equal(t, tt.expectedOrder, filter.GetSortOrder())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"valid field", "host", "host"},
|
||||
{"empty string", "", DefaultSortBy},
|
||||
{"invalid field", "invalid", DefaultSortBy},
|
||||
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseSortField(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortOrder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"asc lowercase", "asc", "asc"},
|
||||
{"desc lowercase", "desc", "desc"},
|
||||
{"ASC uppercase", "ASC", "asc"},
|
||||
{"DESC uppercase", "DESC", "desc"},
|
||||
{"invalid", "invalid", DefaultSortOrder},
|
||||
{"empty", "", DefaultSortOrder},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseSortOrder(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating pointers
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
|
||||
@@ -7,4 +7,7 @@ import (
|
||||
type Manager interface {
|
||||
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
|
||||
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
|
||||
CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error)
|
||||
StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int)
|
||||
StopPeriodicCleanup()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package manager
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -19,6 +20,7 @@ type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
geo geolocation.Geolocation
|
||||
cleanupCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
||||
@@ -78,6 +80,74 @@ func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID st
|
||||
return logs, totalCount, nil
|
||||
}
|
||||
|
||||
// CleanupOldAccessLogs deletes access logs older than the specified retention period
|
||||
func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
||||
if retentionDays <= 0 {
|
||||
log.WithContext(ctx).Debug("access log cleanup skipped: retention days is 0 or negative")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
cutoffTime := time.Now().AddDate(0, 0, -retentionDays)
|
||||
deletedCount, err := m.store.DeleteOldAccessLogs(ctx, cutoffTime)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup old access logs: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
log.WithContext(ctx).Infof("cleaned up %d access logs older than %d days", deletedCount, retentionDays)
|
||||
}
|
||||
|
||||
return deletedCount, nil
|
||||
}
|
||||
|
||||
// StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs
|
||||
func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
||||
if retentionDays <= 0 {
|
||||
log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative")
|
||||
return
|
||||
}
|
||||
|
||||
if cleanupIntervalHours <= 0 {
|
||||
cleanupIntervalHours = 24
|
||||
}
|
||||
|
||||
cleanupCtx, cancel := context.WithCancel(ctx)
|
||||
m.cleanupCancel = cancel
|
||||
|
||||
cleanupInterval := time.Duration(cleanupIntervalHours) * time.Hour
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run cleanup immediately on startup
|
||||
log.WithContext(cleanupCtx).Infof("starting access log cleanup routine (retention: %d days, interval: %d hours)", retentionDays, cleanupIntervalHours)
|
||||
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||
log.WithContext(cleanupCtx).Errorf("initial access log cleanup failed: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cleanupCtx.Done():
|
||||
log.WithContext(cleanupCtx).Info("stopping access log cleanup routine")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||
log.WithContext(cleanupCtx).Errorf("periodic access log cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StopPeriodicCleanup stops the periodic cleanup routine
|
||||
func (m *managerImpl) StopPeriodicCleanup() {
|
||||
if m.cleanupCancel != nil {
|
||||
m.cleanupCancel()
|
||||
}
|
||||
}
|
||||
|
||||
// resolveUserFilters converts user email/name filters to user ID filter
|
||||
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
||||
if filter.UserEmail == nil && filter.UserName == nil {
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func TestCleanupOldAccessLogs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
retentionDays int
|
||||
setupMock func(*store.MockStore)
|
||||
expectedCount int64
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "cleanup logs older than retention period",
|
||||
retentionDays: 30,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||
timeDiff := olderThan.Sub(expectedCutoff)
|
||||
if timeDiff.Abs() > time.Second {
|
||||
t.Errorf("cutoff time not as expected: got %v, want ~%v", olderThan, expectedCutoff)
|
||||
}
|
||||
return 5, nil
|
||||
})
|
||||
},
|
||||
expectedCount: 5,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "no logs to cleanup",
|
||||
retentionDays: 30,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(0), nil)
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "zero retention days skips cleanup",
|
||||
retentionDays: 0,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
// No expectations - DeleteOldAccessLogs should not be called
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "negative retention days skips cleanup",
|
||||
retentionDays: -10,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
// No expectations - DeleteOldAccessLogs should not be called
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
deletedCount, err := manager.CleanupOldAccessLogs(ctx, tt.retentionDays)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expectedCount, deletedCount, "unexpected number of deleted logs")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupWithExactBoundary(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||
timeDiff := olderThan.Sub(expectedCutoff)
|
||||
assert.Less(t, timeDiff.Abs(), time.Second, "cutoff time should be close to expected value")
|
||||
return 1, nil
|
||||
})
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
deletedCount, err := manager.CleanupOldAccessLogs(ctx, 30)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), deletedCount)
|
||||
}
|
||||
|
||||
func TestStartPeriodicCleanup(t *testing.T) {
|
||||
t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
// No expectations - cleanup should not run
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 0, 1)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// If DeleteOldAccessLogs was called, the test will fail due to unexpected call
|
||||
})
|
||||
|
||||
t.Run("periodic cleanup runs immediately on start", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(2), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Expectations verified by gomock on defer ctrl.Finish()
|
||||
})
|
||||
|
||||
t.Run("periodic cleanup stops on context cancel", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(1), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
})
|
||||
|
||||
t.Run("cleanup interval defaults to 24 hours when invalid", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(0), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
})
|
||||
|
||||
t.Run("cleanup interval uses configured hours", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(3), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 12)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopPeriodicCleanup(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(1), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Expectations verified by gomock - would fail if more than 1 call happened
|
||||
}
|
||||
|
||||
func TestStopPeriodicCleanup_NotStarted(t *testing.T) {
|
||||
manager := &managerImpl{}
|
||||
|
||||
// Should not panic if cleanup was never started
|
||||
manager.StopPeriodicCleanup()
|
||||
}
|
||||
Reference in New Issue
Block a user