Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy

This commit is contained in:
pascal
2026-02-20 00:35:10 +01:00
21 changed files with 859 additions and 9 deletions

View File

@@ -28,8 +28,8 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
@@ -1562,8 +1562,10 @@ func (e *Engine) receiveSignalEvents() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
start := time.Now()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
gotLock := time.Since(start)
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
@@ -1587,6 +1589,8 @@ func (e *Engine) receiveSignalEvents() {
return err
}
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
if msg.Body.Type == sProto.Body_OFFER {
conn.OnRemoteOffer(*offerAnswer)
} else {

View File

@@ -351,6 +351,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
logger.Errorf("failed to update domain prefixes: %v", err)
}
// Allow time for route changes to be applied before sending
// the DNS response (relevant on iOS where setTunnelNetworkSettings
// is asynchronous).
waitForRouteSettlement(logger)
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
}
}

View File

@@ -0,0 +1,20 @@
//go:build ios
package dnsinterceptor
import (
"time"
log "github.com/sirupsen/logrus"
)
// routeSettleDelay is how long the DNS response is held back so that iOS
// has time to apply freshly added tunnel routes.
const routeSettleDelay = 500 * time.Millisecond

// waitForRouteSettlement introduces a short delay on iOS to allow
// setTunnelNetworkSettings to apply route changes before the DNS
// response reaches the application. Without this, the first request
// to a newly resolved domain may bypass the tunnel.
// NOTE(review): the 500ms value is a heuristic, not an API guarantee —
// confirm it is sufficient on slower devices.
func waitForRouteSettlement(logger *log.Entry) {
	logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay)
	time.Sleep(routeSettleDelay)
}

View File

@@ -0,0 +1,12 @@
//go:build !ios
package dnsinterceptor
import log "github.com/sirupsen/logrus"
// waitForRouteSettlement is the non-iOS counterpart of the iOS-only
// settlement delay; it intentionally does nothing.
func waitForRouteSettlement(_ *log.Entry) {
	// No-op on non-iOS platforms: route changes are applied synchronously by
	// the kernel, so no settlement delay is needed before the DNS response
	// reaches the application. The delay is only required on iOS where
	// setTunnelNetworkSettings applies routes asynchronously.
}

View File

@@ -1,8 +1,6 @@
package txt
import (
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/formatter/levels"
@@ -18,7 +16,7 @@ type TextFormatter struct {
// NewTextFormatter creates a new instance of the TextFormatter using
// an RFC 3339 timestamp layout extended with millisecond precision
// (e.g. 2021-02-21T01:10:30.000Z).
func NewTextFormatter() *TextFormatter {
	return &TextFormatter{
		levelDesc: levels.ValidLevelDesc,
		// Millisecond-precision RFC 3339; plain time.RFC3339 was replaced
		// because sub-second ordering of log lines was lost.
		timestampFormat: "2006-01-02T15:04:05.000Z07:00",
	}
}

View File

@@ -21,6 +21,6 @@ func TestLogTextFormat(t *testing.T) {
result, _ := formatter.Format(someEntry)
parsedString := string(result)
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
expectedString := "^2021-02-21T01:10:30.000Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}

View File

@@ -3,6 +3,7 @@ package accesslogs
import (
"net/http"
"strconv"
"strings"
"time"
)
@@ -11,15 +12,39 @@ const (
DefaultPageSize = 50
// MaxPageSize is the maximum number of records allowed per page
MaxPageSize = 100
// Default sorting
DefaultSortBy = "timestamp"
DefaultSortOrder = "desc"
)
// AccessLogFilter holds pagination and filtering parameters for access logs
// Valid sortable fields mapped to their database column names or expressions
// For multi-column sorts, columns are separated by comma (e.g., "host, path")
var validSortFields = map[string]string{
"timestamp": "timestamp",
"url": "host, path", // Sort by host first, then path
"host": "host",
"path": "path",
"method": "method",
"status_code": "status_code",
"duration": "duration",
"source_ip": "location_connection_ip",
"user_id": "user_id",
"auth_method": "auth_method_used",
"reason": "reason",
}
// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs
type AccessLogFilter struct {
// Page is the current page number (1-indexed)
Page int
// PageSize is the number of records per page
PageSize int
// Sorting parameters
SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason
SortOrder string // Sort order: asc or desc (default: desc)
// Filtering parameters
Search *string // General search across log ID, host, path, source IP, and user fields
SourceIP *string // Filter by source IP address
@@ -35,13 +60,16 @@ type AccessLogFilter struct {
EndDate *time.Time // Filter by timestamp <= end_date
}
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
queryParams := r.URL.Query()
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
f.SortBy = parseSortField(queryParams.Get("sort_by"))
f.SortOrder = parseSortOrder(queryParams.Get("sort_order"))
f.Search = parseOptionalString(queryParams.Get("search"))
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
f.Host = parseOptionalString(queryParams.Get("host"))
@@ -107,3 +135,44 @@ func (f *AccessLogFilter) GetOffset() int {
func (f *AccessLogFilter) GetLimit() int {
return f.PageSize
}
// GetSortColumn returns the database column (or comma-separated column
// list) mapped to f.SortBy. Unknown values fall back to the default
// sort field's column, so the result is always safe to embed in SQL.
func (f *AccessLogFilter) GetSortColumn() string {
	column, ok := validSortFields[f.SortBy]
	if !ok {
		return validSortFields[DefaultSortBy]
	}
	return column
}
// GetSortOrder returns the validated sort order ("asc" or "desc");
// anything else falls back to DefaultSortOrder.
func (f *AccessLogFilter) GetSortOrder() string {
	switch f.SortOrder {
	case "asc", "desc":
		return f.SortOrder
	default:
		return DefaultSortOrder
	}
}
// parseSortField validates and returns the sort field, defaulting if invalid
func parseSortField(s string) string {
if s == "" {
return DefaultSortBy
}
// Check if the field is valid
if _, ok := validSortFields[s]; ok {
return s
}
return DefaultSortBy
}
// parseSortOrder validates the requested sort order case-insensitively
// and returns the normalized lowercase value; empty or unrecognized
// input yields DefaultSortOrder.
func parseSortOrder(s string) string {
	switch strings.ToLower(s) {
	case "asc":
		return "asc"
	case "desc":
		return "desc"
	default:
		return DefaultSortOrder
	}
}

View File

@@ -361,6 +361,205 @@ func TestParseOptionalRFC3339(t *testing.T) {
}
}
func TestAccessLogFilter_SortingDefaults(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp")
assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc")
assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp")
assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc")
}
func TestAccessLogFilter_ValidSortFields(t *testing.T) {
tests := []struct {
name string
sortBy string
expectedColumn string
expectedSortByVal string
}{
{"timestamp", "timestamp", "timestamp", "timestamp"},
{"url", "url", "host, path", "url"},
{"host", "host", "host", "host"},
{"path", "path", "path", "path"},
{"method", "method", "method", "method"},
{"status_code", "status_code", "status_code", "status_code"},
{"duration", "duration", "duration", "duration"},
{"source_ip", "source_ip", "location_connection_ip", "source_ip"},
{"user_id", "user_id", "user_id", "user_id"},
{"auth_method", "auth_method", "auth_method_used", "auth_method"},
{"reason", "reason", "reason", "reason"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil)
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch")
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch")
})
}
}
func TestAccessLogFilter_InvalidSortField(t *testing.T) {
tests := []struct {
name string
sortBy string
expected string
}{
{"invalid field", "invalid_field", DefaultSortBy},
{"empty field", "", DefaultSortBy},
{"malicious input", "timestamp--DROP", DefaultSortBy},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
q := req.URL.Query()
q.Set("sort_by", tt.sortBy)
req.URL.RawQuery = q.Encode()
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp")
assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn())
})
}
}
func TestAccessLogFilter_SortOrder(t *testing.T) {
tests := []struct {
name string
sortOrder string
expected string
}{
{"ascending", "asc", "asc"},
{"descending", "desc", "desc"},
{"uppercase ASC", "ASC", "asc"},
{"uppercase DESC", "DESC", "desc"},
{"mixed case Asc", "Asc", "asc"},
{"invalid order", "invalid", DefaultSortOrder},
{"empty order", "", DefaultSortOrder},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil)
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch")
})
}
}
func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) {
tests := []struct {
name string
sortBy string
sortOrder string
expectedColumn string
expectedOrder string
}{
{
name: "sort by host ascending",
sortBy: "host",
sortOrder: "asc",
expectedColumn: "host",
expectedOrder: "asc",
},
{
name: "sort by duration descending",
sortBy: "duration",
sortOrder: "desc",
expectedColumn: "duration",
expectedOrder: "desc",
},
{
name: "sort by status_code ascending",
sortBy: "status_code",
sortOrder: "asc",
expectedColumn: "status_code",
expectedOrder: "asc",
},
{
name: "invalid sort with valid order",
sortBy: "invalid",
sortOrder: "asc",
expectedColumn: "timestamp",
expectedOrder: "asc",
},
{
name: "valid sort with invalid order",
sortBy: "method",
sortOrder: "invalid",
expectedColumn: "method",
expectedOrder: DefaultSortOrder,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil)
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn())
assert.Equal(t, tt.expectedOrder, filter.GetSortOrder())
})
}
}
func TestParseSortField(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"valid field", "host", "host"},
{"empty string", "", DefaultSortBy},
{"invalid field", "invalid", DefaultSortBy},
{"malicious input", "timestamp--DROP", DefaultSortBy},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseSortField(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseSortOrder(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"asc lowercase", "asc", "asc"},
{"desc lowercase", "desc", "desc"},
{"ASC uppercase", "ASC", "asc"},
{"DESC uppercase", "DESC", "desc"},
{"invalid", "invalid", DefaultSortOrder},
{"empty", "", DefaultSortOrder},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseSortOrder(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// Helper functions for creating pointers
func strPtr(s string) *string {
return &s

View File

@@ -7,4 +7,7 @@ import (
type Manager interface {
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error)
StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int)
StopPeriodicCleanup()
}

View File

@@ -3,6 +3,7 @@ package manager
import (
"context"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -19,6 +20,7 @@ type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
geo geolocation.Geolocation
cleanupCancel context.CancelFunc
}
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
@@ -78,6 +80,74 @@ func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID st
return logs, totalCount, nil
}
// CleanupOldAccessLogs deletes access logs older than the specified retention period.
// A retentionDays value of zero or less disables cleanup entirely: the store is
// not touched and (0, nil) is returned. On success it returns how many rows the
// store reported as deleted.
func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
	if retentionDays <= 0 {
		log.WithContext(ctx).Debug("access log cleanup skipped: retention days is 0 or negative")
		return 0, nil
	}

	// Everything strictly older than this instant is eligible for deletion.
	cutoff := time.Now().AddDate(0, 0, -retentionDays)

	deleted, err := m.store.DeleteOldAccessLogs(ctx, cutoff)
	if err != nil {
		log.WithContext(ctx).Errorf("failed to cleanup old access logs: %v", err)
		return 0, err
	}

	// Only log when work was actually done to keep periodic runs quiet.
	if deleted > 0 {
		log.WithContext(ctx).Infof("cleaned up %d access logs older than %d days", deleted, retentionDays)
	}
	return deleted, nil
}
// StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs.
// The goroutine runs one cleanup immediately, then once per interval, and exits when the
// derived context is cancelled (via StopPeriodicCleanup or cancellation of ctx).
// A retentionDays value <= 0 disables the routine entirely; an interval <= 0 defaults to 24h.
// NOTE(review): calling StartPeriodicCleanup twice overwrites m.cleanupCancel without
// cancelling the previous goroutine, leaking it — confirm callers invoke this only once.
func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
	// Retention <= 0 means "keep logs indefinitely": never start the routine.
	if retentionDays <= 0 {
		log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative")
		return
	}
	// Guard against a missing/invalid interval; default to one run per day.
	if cleanupIntervalHours <= 0 {
		cleanupIntervalHours = 24
	}
	// Derive a cancellable context so StopPeriodicCleanup can end the goroutine
	// independently of the parent ctx.
	cleanupCtx, cancel := context.WithCancel(ctx)
	m.cleanupCancel = cancel
	cleanupInterval := time.Duration(cleanupIntervalHours) * time.Hour
	ticker := time.NewTicker(cleanupInterval)
	go func() {
		defer ticker.Stop()
		// Run cleanup immediately on startup
		log.WithContext(cleanupCtx).Infof("starting access log cleanup routine (retention: %d days, interval: %d hours)", retentionDays, cleanupIntervalHours)
		if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
			log.WithContext(cleanupCtx).Errorf("initial access log cleanup failed: %v", err)
		}
		for {
			select {
			case <-cleanupCtx.Done():
				log.WithContext(cleanupCtx).Info("stopping access log cleanup routine")
				return
			case <-ticker.C:
				if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
					log.WithContext(cleanupCtx).Errorf("periodic access log cleanup failed: %v", err)
				}
			}
		}
	}()
}
// StopPeriodicCleanup stops the periodic cleanup routine by cancelling the
// context created in StartPeriodicCleanup. It is safe to call when the routine
// was never started (cleanupCancel is nil). It does not block waiting for the
// goroutine to observe the cancellation and exit.
func (m *managerImpl) StopPeriodicCleanup() {
	if m.cleanupCancel != nil {
		m.cleanupCancel()
	}
}
// resolveUserFilters converts user email/name filters to user ID filter
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
if filter.UserEmail == nil && filter.UserName == nil {

View File

@@ -0,0 +1,281 @@
package manager
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/store"
)
func TestCleanupOldAccessLogs(t *testing.T) {
tests := []struct {
name string
retentionDays int
setupMock func(*store.MockStore)
expectedCount int64
expectedError bool
}{
{
name: "cleanup logs older than retention period",
retentionDays: 30,
setupMock: func(mockStore *store.MockStore) {
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
expectedCutoff := time.Now().AddDate(0, 0, -30)
timeDiff := olderThan.Sub(expectedCutoff)
if timeDiff.Abs() > time.Second {
t.Errorf("cutoff time not as expected: got %v, want ~%v", olderThan, expectedCutoff)
}
return 5, nil
})
},
expectedCount: 5,
expectedError: false,
},
{
name: "no logs to cleanup",
retentionDays: 30,
setupMock: func(mockStore *store.MockStore) {
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
Return(int64(0), nil)
},
expectedCount: 0,
expectedError: false,
},
{
name: "zero retention days skips cleanup",
retentionDays: 0,
setupMock: func(mockStore *store.MockStore) {
// No expectations - DeleteOldAccessLogs should not be called
},
expectedCount: 0,
expectedError: false,
},
{
name: "negative retention days skips cleanup",
retentionDays: -10,
setupMock: func(mockStore *store.MockStore) {
// No expectations - DeleteOldAccessLogs should not be called
},
expectedCount: 0,
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore)
manager := &managerImpl{
store: mockStore,
}
ctx := context.Background()
deletedCount, err := manager.CleanupOldAccessLogs(ctx, tt.retentionDays)
if tt.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.expectedCount, deletedCount, "unexpected number of deleted logs")
})
}
}
func TestCleanupWithExactBoundary(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
expectedCutoff := time.Now().AddDate(0, 0, -30)
timeDiff := olderThan.Sub(expectedCutoff)
assert.Less(t, timeDiff.Abs(), time.Second, "cutoff time should be close to expected value")
return 1, nil
})
manager := &managerImpl{
store: mockStore,
}
ctx := context.Background()
deletedCount, err := manager.CleanupOldAccessLogs(ctx, 30)
require.NoError(t, err)
assert.Equal(t, int64(1), deletedCount)
}
func TestStartPeriodicCleanup(t *testing.T) {
t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
// No expectations - cleanup should not run
manager := &managerImpl{
store: mockStore,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
manager.StartPeriodicCleanup(ctx, 0, 1)
time.Sleep(100 * time.Millisecond)
// If DeleteOldAccessLogs was called, the test will fail due to unexpected call
})
t.Run("periodic cleanup runs immediately on start", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
Return(int64(2), nil).
Times(1)
manager := &managerImpl{
store: mockStore,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
manager.StartPeriodicCleanup(ctx, 30, 24)
time.Sleep(200 * time.Millisecond)
// Expectations verified by gomock on defer ctrl.Finish()
})
t.Run("periodic cleanup stops on context cancel", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
Return(int64(1), nil).
Times(1)
manager := &managerImpl{
store: mockStore,
}
ctx, cancel := context.WithCancel(context.Background())
manager.StartPeriodicCleanup(ctx, 30, 24)
time.Sleep(100 * time.Millisecond)
cancel()
time.Sleep(200 * time.Millisecond)
})
t.Run("cleanup interval defaults to 24 hours when invalid", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
Return(int64(0), nil).
Times(1)
manager := &managerImpl{
store: mockStore,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
manager.StartPeriodicCleanup(ctx, 30, 0)
time.Sleep(100 * time.Millisecond)
manager.StopPeriodicCleanup()
})
t.Run("cleanup interval uses configured hours", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
Return(int64(3), nil).
Times(1)
manager := &managerImpl{
store: mockStore,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
manager.StartPeriodicCleanup(ctx, 30, 12)
time.Sleep(100 * time.Millisecond)
manager.StopPeriodicCleanup()
})
}
func TestStopPeriodicCleanup(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
Return(int64(1), nil).
Times(1)
manager := &managerImpl{
store: mockStore,
}
ctx := context.Background()
manager.StartPeriodicCleanup(ctx, 30, 24)
time.Sleep(100 * time.Millisecond)
manager.StopPeriodicCleanup()
time.Sleep(200 * time.Millisecond)
// Expectations verified by gomock - would fail if more than 1 call happened
}
func TestStopPeriodicCleanup_NotStarted(t *testing.T) {
manager := &managerImpl{}
// Should not panic if cleanup was never started
manager.StopPeriodicCleanup()
}

View File

@@ -200,6 +200,11 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
accessLogManager.StartPeriodicCleanup(
context.Background(),
s.Config.ReverseProxy.AccessLogRetentionDays,
s.Config.ReverseProxy.AccessLogCleanupIntervalHours,
)
return accessLogManager
})
}

View File

@@ -200,4 +200,13 @@ type ReverseProxy struct {
// request headers if the peer's address falls within one of these
// trusted IP prefixes.
TrustedPeers []netip.Prefix
// AccessLogRetentionDays specifies the number of days to retain access logs.
// Logs older than this duration will be automatically deleted during cleanup.
// A value of 0 or negative means logs are kept indefinitely (no cleanup).
AccessLogRetentionDays int
// AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine.
// Defaults to 24 hours if not set or set to 0.
AccessLogCleanupIntervalHours int
}

View File

@@ -157,6 +157,18 @@ type testSetup struct {
// testAccessLogManager is a minimal mock for accesslogs.Manager.
// Every method is a successful no-op so tests can satisfy the interface
// without a real store.
type testAccessLogManager struct{}

// CleanupOldAccessLogs reports that nothing was deleted and no error occurred.
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
	return 0, nil
}

// StartPeriodicCleanup is a no-op; tests never run background cleanup.
// (Redundant bare `return` removed — staticcheck S1023.)
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {}

// StopPeriodicCleanup is a no-op.
func (m *testAccessLogManager) StopPeriodicCleanup() {}

// SaveAccessLog accepts and discards the entry.
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
	return nil
}

View File

@@ -5083,8 +5083,20 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
query = s.applyAccessLogFilters(query, filter)
sortColumns := filter.GetSortColumn()
sortOrder := strings.ToUpper(filter.GetSortOrder())
var orderClauses []string
for _, col := range strings.Split(sortColumns, ",") {
col = strings.TrimSpace(col)
if col != "" {
orderClauses = append(orderClauses, col+" "+sortOrder)
}
}
orderClause := strings.Join(orderClauses, ", ")
query = query.
Order("timestamp DESC").
Order(orderClause).
Limit(filter.GetLimit()).
Offset(filter.GetOffset())
@@ -5101,6 +5113,20 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
return logs, totalCount, nil
}
// DeleteOldAccessLogs deletes all access logs older than the specified time.
// It returns the number of rows removed. Database errors are logged here and
// surfaced to the caller as an internal status error with a zero count.
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
	// Strict "<" comparison: entries with timestamp exactly equal to
	// olderThan are retained.
	result := s.db.WithContext(ctx).
		Where("timestamp < ?", olderThan).
		Delete(&accesslogs.AccessLogEntry{})
	if result.Error != nil {
		log.WithContext(ctx).Errorf("failed to delete old access logs: %v", result.Error)
		return 0, status.Errorf(status.Internal, "failed to delete old access logs")
	}
	return result.RowsAffected, nil
}
// applyAccessLogFilters applies filter conditions to the query
func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.AccessLogFilter) *gorm.DB {
if filter.Search != nil {

View File

@@ -270,6 +270,7 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error

View File

@@ -475,6 +475,21 @@ func (mr *MockStoreMockRecorder) DeleteNetworkRouter(ctx, accountID, routerID in
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteNetworkRouter", reflect.TypeOf((*MockStore)(nil).DeleteNetworkRouter), ctx, accountID, routerID)
}
// NOTE(review): this pair follows the gomock-generated pattern used by the
// surrounding file — prefer regenerating the mock over hand-editing it.

// DeleteOldAccessLogs mocks base method.
func (m *MockStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteOldAccessLogs", ctx, olderThan)
	ret0, _ := ret[0].(int64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// DeleteOldAccessLogs indicates an expected call of DeleteOldAccessLogs.
func (mr *MockStoreMockRecorder) DeleteOldAccessLogs(ctx, olderThan interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAccessLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAccessLogs), ctx, olderThan)
}
// DeletePAT mocks base method.
func (m *MockStore) DeletePAT(ctx context.Context, userID, patID string) error {
m.ctrl.T.Helper()

View File

@@ -737,6 +737,14 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
if initiatorUserId != activity.SystemInitiator {
freshInitiator, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, initiatorUserId)
if err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err)
}
initiatorUser = freshInitiator
}
oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
if err != nil {
return false, nil, nil, nil, err
@@ -864,7 +872,10 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
return nil
}
// @todo double check these
if !initiatorUser.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
}
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}

View File

@@ -2031,3 +2031,87 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
t.Logf("Duplicate email error: %v", err)
})
}
func TestValidateUserUpdate_RejectsNonAdminInitiator(t *testing.T) {
groupsMap := map[string]*types.Group{}
initiator := &types.User{
Id: "initiator",
Role: types.UserRoleUser,
}
oldUser := &types.User{
Id: "target",
Role: types.UserRoleUser,
}
update := &types.User{
Id: "target",
Role: types.UserRoleOwner,
}
err := validateUserUpdate(groupsMap, initiator, oldUser, update)
require.Error(t, err, "regular user should not be able to promote to owner")
assert.Contains(t, err.Error(), "only admins and owners can update users")
}
func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), "account1", "owner1", "", "", "", false)
adminID := "admin1"
account.Users[adminID] = types.NewAdminUser(adminID)
targetID := "target1"
account.Users[targetID] = types.NewRegularUser(targetID, "", "")
require.NoError(t, s.SaveAccount(context.Background(), account))
demotedAdmin, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, adminID)
require.NoError(t, err)
demotedAdmin.Role = types.UserRoleUser
require.NoError(t, s.SaveUser(context.Background(), demotedAdmin))
staleInitiator := &types.User{
Id: adminID,
AccountID: account.Id,
Role: types.UserRoleAdmin,
}
permissionsManager := permissions.NewManager(s)
am := DefaultAccountManager{
Store: s,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsManager,
}
settings, err := s.GetAccountSettings(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err)
groups, err := s.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err)
groupsMap := make(map[string]*types.Group, len(groups))
for _, g := range groups {
groupsMap[g.ID] = g
}
update := &types.User{
Id: targetID,
Role: types.UserRoleAdmin,
}
err = s.ExecuteInTransaction(context.Background(), func(tx store.Store) error {
_, _, _, _, txErr := am.processUserUpdate(
context.Background(), tx, groupsMap, account.Id, adminID, staleInitiator, update, false, settings,
)
return txErr
})
require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted")
assert.Contains(t, err.Error(), "only admins and owners can update users")
targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID)
require.NoError(t, err)
assert.Equal(t, types.UserRoleUser, targetUser.Role)
}

View File

@@ -166,6 +166,18 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
// testAccessLogManager provides access log storage for testing.
// Every method is a successful no-op so integration tests can satisfy the
// accesslogs.Manager interface without a real database.
type testAccessLogManager struct{}

// CleanupOldAccessLogs reports that nothing was deleted and no error occurred.
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
	return 0, nil
}

// StartPeriodicCleanup never schedules any background work in tests.
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
	// noop
}

// StopPeriodicCleanup has nothing to stop.
func (m *testAccessLogManager) StopPeriodicCleanup() {
	// noop
}

// SaveAccessLog accepts and discards the entry.
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
	return nil
}

View File

@@ -7409,6 +7409,20 @@ paths:
minimum: 1
maximum: 100
description: Number of items per page (max 100)
- in: query
name: sort_by
schema:
type: string
enum: [timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason]
default: timestamp
description: Field to sort by (url sorts by host then path)
- in: query
name: sort_order
schema:
type: string
enum: [asc, desc]
default: desc
description: Sort order (ascending or descending)
- in: query
name: search
schema: