mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy
This commit is contained in:
@@ -28,8 +28,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
@@ -1562,8 +1562,10 @@ func (e *Engine) receiveSignalEvents() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
start := time.Now()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
gotLock := time.Since(start)
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -1587,6 +1589,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
|
||||
@@ -351,6 +351,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
|
||||
logger.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
// Allow time for route changes to be applied before sending
|
||||
// the DNS response (relevant on iOS where setTunnelNetworkSettings
|
||||
// is asynchronous).
|
||||
waitForRouteSettlement(logger)
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
||||
}
|
||||
}
|
||||
|
||||
20
client/internal/routemanager/dnsinterceptor/handler_ios.go
Normal file
20
client/internal/routemanager/dnsinterceptor/handler_ios.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build ios
|
||||
|
||||
package dnsinterceptor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const routeSettleDelay = 500 * time.Millisecond
|
||||
|
||||
// waitForRouteSettlement introduces a short delay on iOS to allow
|
||||
// setTunnelNetworkSettings to apply route changes before the DNS
|
||||
// response reaches the application. Without this, the first request
|
||||
// to a newly resolved domain may bypass the tunnel.
|
||||
func waitForRouteSettlement(logger *log.Entry) {
|
||||
logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay)
|
||||
time.Sleep(routeSettleDelay)
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !ios
|
||||
|
||||
package dnsinterceptor
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
func waitForRouteSettlement(_ *log.Entry) {
|
||||
// No-op on non-iOS platforms: route changes are applied synchronously by
|
||||
// the kernel, so no settlement delay is needed before the DNS response
|
||||
// reaches the application. The delay is only required on iOS where
|
||||
// setTunnelNetworkSettings applies routes asynchronously.
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
package txt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/levels"
|
||||
@@ -18,7 +16,7 @@ type TextFormatter struct {
|
||||
func NewTextFormatter() *TextFormatter {
|
||||
return &TextFormatter{
|
||||
levelDesc: levels.ValidLevelDesc,
|
||||
timestampFormat: time.RFC3339, // or RFC3339
|
||||
timestampFormat: "2006-01-02T15:04:05.000Z07:00",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ func TestLogTextFormat(t *testing.T) {
|
||||
result, _ := formatter.Format(someEntry)
|
||||
|
||||
parsedString := string(result)
|
||||
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||
expectedString := "^2021-02-21T01:10:30.000Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package accesslogs
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -11,15 +12,39 @@ const (
|
||||
DefaultPageSize = 50
|
||||
// MaxPageSize is the maximum number of records allowed per page
|
||||
MaxPageSize = 100
|
||||
|
||||
// Default sorting
|
||||
DefaultSortBy = "timestamp"
|
||||
DefaultSortOrder = "desc"
|
||||
)
|
||||
|
||||
// AccessLogFilter holds pagination and filtering parameters for access logs
|
||||
// Valid sortable fields mapped to their database column names or expressions
|
||||
// For multi-column sorts, columns are separated by comma (e.g., "host, path")
|
||||
var validSortFields = map[string]string{
|
||||
"timestamp": "timestamp",
|
||||
"url": "host, path", // Sort by host first, then path
|
||||
"host": "host",
|
||||
"path": "path",
|
||||
"method": "method",
|
||||
"status_code": "status_code",
|
||||
"duration": "duration",
|
||||
"source_ip": "location_connection_ip",
|
||||
"user_id": "user_id",
|
||||
"auth_method": "auth_method_used",
|
||||
"reason": "reason",
|
||||
}
|
||||
|
||||
// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs
|
||||
type AccessLogFilter struct {
|
||||
// Page is the current page number (1-indexed)
|
||||
Page int
|
||||
// PageSize is the number of records per page
|
||||
PageSize int
|
||||
|
||||
// Sorting parameters
|
||||
SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason
|
||||
SortOrder string // Sort order: asc or desc (default: desc)
|
||||
|
||||
// Filtering parameters
|
||||
Search *string // General search across log ID, host, path, source IP, and user fields
|
||||
SourceIP *string // Filter by source IP address
|
||||
@@ -35,13 +60,16 @@ type AccessLogFilter struct {
|
||||
EndDate *time.Time // Filter by timestamp <= end_date
|
||||
}
|
||||
|
||||
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
|
||||
// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters
|
||||
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||
queryParams := r.URL.Query()
|
||||
|
||||
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||
|
||||
f.SortBy = parseSortField(queryParams.Get("sort_by"))
|
||||
f.SortOrder = parseSortOrder(queryParams.Get("sort_order"))
|
||||
|
||||
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||
@@ -107,3 +135,44 @@ func (f *AccessLogFilter) GetOffset() int {
|
||||
func (f *AccessLogFilter) GetLimit() int {
|
||||
return f.PageSize
|
||||
}
|
||||
|
||||
// GetSortColumn returns the validated database column name for sorting
|
||||
func (f *AccessLogFilter) GetSortColumn() string {
|
||||
if column, ok := validSortFields[f.SortBy]; ok {
|
||||
return column
|
||||
}
|
||||
return validSortFields[DefaultSortBy]
|
||||
}
|
||||
|
||||
// GetSortOrder returns the validated sort order (ASC or DESC)
|
||||
func (f *AccessLogFilter) GetSortOrder() string {
|
||||
if f.SortOrder == "asc" || f.SortOrder == "desc" {
|
||||
return f.SortOrder
|
||||
}
|
||||
return DefaultSortOrder
|
||||
}
|
||||
|
||||
// parseSortField validates and returns the sort field, defaulting if invalid
|
||||
func parseSortField(s string) string {
|
||||
if s == "" {
|
||||
return DefaultSortBy
|
||||
}
|
||||
// Check if the field is valid
|
||||
if _, ok := validSortFields[s]; ok {
|
||||
return s
|
||||
}
|
||||
return DefaultSortBy
|
||||
}
|
||||
|
||||
// parseSortOrder validates and returns the sort order, defaulting if invalid
|
||||
func parseSortOrder(s string) string {
|
||||
if s == "" {
|
||||
return DefaultSortOrder
|
||||
}
|
||||
// Normalize to lowercase
|
||||
s = strings.ToLower(s)
|
||||
if s == "asc" || s == "desc" {
|
||||
return s
|
||||
}
|
||||
return DefaultSortOrder
|
||||
}
|
||||
|
||||
@@ -361,6 +361,205 @@ func TestParseOptionalRFC3339(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_SortingDefaults(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp")
|
||||
assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc")
|
||||
assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp")
|
||||
assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc")
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_ValidSortFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
expectedColumn string
|
||||
expectedSortByVal string
|
||||
}{
|
||||
{"timestamp", "timestamp", "timestamp", "timestamp"},
|
||||
{"url", "url", "host, path", "url"},
|
||||
{"host", "host", "host", "host"},
|
||||
{"path", "path", "path", "path"},
|
||||
{"method", "method", "method", "method"},
|
||||
{"status_code", "status_code", "status_code", "status_code"},
|
||||
{"duration", "duration", "duration", "duration"},
|
||||
{"source_ip", "source_ip", "location_connection_ip", "source_ip"},
|
||||
{"user_id", "user_id", "user_id", "user_id"},
|
||||
{"auth_method", "auth_method", "auth_method_used", "auth_method"},
|
||||
{"reason", "reason", "reason", "reason"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch")
|
||||
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_InvalidSortField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
expected string
|
||||
}{
|
||||
{"invalid field", "invalid_field", DefaultSortBy},
|
||||
{"empty field", "", DefaultSortBy},
|
||||
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
q := req.URL.Query()
|
||||
q.Set("sort_by", tt.sortBy)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp")
|
||||
assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_SortOrder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortOrder string
|
||||
expected string
|
||||
}{
|
||||
{"ascending", "asc", "asc"},
|
||||
{"descending", "desc", "desc"},
|
||||
{"uppercase ASC", "ASC", "asc"},
|
||||
{"uppercase DESC", "DESC", "desc"},
|
||||
{"mixed case Asc", "Asc", "asc"},
|
||||
{"invalid order", "invalid", DefaultSortOrder},
|
||||
{"empty order", "", DefaultSortOrder},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
sortOrder string
|
||||
expectedColumn string
|
||||
expectedOrder string
|
||||
}{
|
||||
{
|
||||
name: "sort by host ascending",
|
||||
sortBy: "host",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "host",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "sort by duration descending",
|
||||
sortBy: "duration",
|
||||
sortOrder: "desc",
|
||||
expectedColumn: "duration",
|
||||
expectedOrder: "desc",
|
||||
},
|
||||
{
|
||||
name: "sort by status_code ascending",
|
||||
sortBy: "status_code",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "status_code",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "invalid sort with valid order",
|
||||
sortBy: "invalid",
|
||||
sortOrder: "asc",
|
||||
expectedColumn: "timestamp",
|
||||
expectedOrder: "asc",
|
||||
},
|
||||
{
|
||||
name: "valid sort with invalid order",
|
||||
sortBy: "method",
|
||||
sortOrder: "invalid",
|
||||
expectedColumn: "method",
|
||||
expectedOrder: DefaultSortOrder,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil)
|
||||
|
||||
filter := &AccessLogFilter{}
|
||||
filter.ParseFromRequest(req)
|
||||
|
||||
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn())
|
||||
assert.Equal(t, tt.expectedOrder, filter.GetSortOrder())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"valid field", "host", "host"},
|
||||
{"empty string", "", DefaultSortBy},
|
||||
{"invalid field", "invalid", DefaultSortBy},
|
||||
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseSortField(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortOrder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"asc lowercase", "asc", "asc"},
|
||||
{"desc lowercase", "desc", "desc"},
|
||||
{"ASC uppercase", "ASC", "asc"},
|
||||
{"DESC uppercase", "DESC", "desc"},
|
||||
{"invalid", "invalid", DefaultSortOrder},
|
||||
{"empty", "", DefaultSortOrder},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseSortOrder(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating pointers
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
|
||||
@@ -7,4 +7,7 @@ import (
|
||||
type Manager interface {
|
||||
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
|
||||
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
|
||||
CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error)
|
||||
StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int)
|
||||
StopPeriodicCleanup()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package manager
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -19,6 +20,7 @@ type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
geo geolocation.Geolocation
|
||||
cleanupCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
||||
@@ -78,6 +80,74 @@ func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID st
|
||||
return logs, totalCount, nil
|
||||
}
|
||||
|
||||
// CleanupOldAccessLogs deletes access logs older than the specified retention period
|
||||
func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
||||
if retentionDays <= 0 {
|
||||
log.WithContext(ctx).Debug("access log cleanup skipped: retention days is 0 or negative")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
cutoffTime := time.Now().AddDate(0, 0, -retentionDays)
|
||||
deletedCount, err := m.store.DeleteOldAccessLogs(ctx, cutoffTime)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup old access logs: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
log.WithContext(ctx).Infof("cleaned up %d access logs older than %d days", deletedCount, retentionDays)
|
||||
}
|
||||
|
||||
return deletedCount, nil
|
||||
}
|
||||
|
||||
// StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs
|
||||
func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
||||
if retentionDays <= 0 {
|
||||
log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative")
|
||||
return
|
||||
}
|
||||
|
||||
if cleanupIntervalHours <= 0 {
|
||||
cleanupIntervalHours = 24
|
||||
}
|
||||
|
||||
cleanupCtx, cancel := context.WithCancel(ctx)
|
||||
m.cleanupCancel = cancel
|
||||
|
||||
cleanupInterval := time.Duration(cleanupIntervalHours) * time.Hour
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run cleanup immediately on startup
|
||||
log.WithContext(cleanupCtx).Infof("starting access log cleanup routine (retention: %d days, interval: %d hours)", retentionDays, cleanupIntervalHours)
|
||||
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||
log.WithContext(cleanupCtx).Errorf("initial access log cleanup failed: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cleanupCtx.Done():
|
||||
log.WithContext(cleanupCtx).Info("stopping access log cleanup routine")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||
log.WithContext(cleanupCtx).Errorf("periodic access log cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StopPeriodicCleanup stops the periodic cleanup routine
|
||||
func (m *managerImpl) StopPeriodicCleanup() {
|
||||
if m.cleanupCancel != nil {
|
||||
m.cleanupCancel()
|
||||
}
|
||||
}
|
||||
|
||||
// resolveUserFilters converts user email/name filters to user ID filter
|
||||
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
||||
if filter.UserEmail == nil && filter.UserName == nil {
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func TestCleanupOldAccessLogs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
retentionDays int
|
||||
setupMock func(*store.MockStore)
|
||||
expectedCount int64
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "cleanup logs older than retention period",
|
||||
retentionDays: 30,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||
timeDiff := olderThan.Sub(expectedCutoff)
|
||||
if timeDiff.Abs() > time.Second {
|
||||
t.Errorf("cutoff time not as expected: got %v, want ~%v", olderThan, expectedCutoff)
|
||||
}
|
||||
return 5, nil
|
||||
})
|
||||
},
|
||||
expectedCount: 5,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "no logs to cleanup",
|
||||
retentionDays: 30,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(0), nil)
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "zero retention days skips cleanup",
|
||||
retentionDays: 0,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
// No expectations - DeleteOldAccessLogs should not be called
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "negative retention days skips cleanup",
|
||||
retentionDays: -10,
|
||||
setupMock: func(mockStore *store.MockStore) {
|
||||
// No expectations - DeleteOldAccessLogs should not be called
|
||||
},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
deletedCount, err := manager.CleanupOldAccessLogs(ctx, tt.retentionDays)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expectedCount, deletedCount, "unexpected number of deleted logs")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupWithExactBoundary(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||
timeDiff := olderThan.Sub(expectedCutoff)
|
||||
assert.Less(t, timeDiff.Abs(), time.Second, "cutoff time should be close to expected value")
|
||||
return 1, nil
|
||||
})
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
deletedCount, err := manager.CleanupOldAccessLogs(ctx, 30)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), deletedCount)
|
||||
}
|
||||
|
||||
func TestStartPeriodicCleanup(t *testing.T) {
|
||||
t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
// No expectations - cleanup should not run
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 0, 1)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// If DeleteOldAccessLogs was called, the test will fail due to unexpected call
|
||||
})
|
||||
|
||||
t.Run("periodic cleanup runs immediately on start", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(2), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Expectations verified by gomock on defer ctrl.Finish()
|
||||
})
|
||||
|
||||
t.Run("periodic cleanup stops on context cancel", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(1), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
})
|
||||
|
||||
t.Run("cleanup interval defaults to 24 hours when invalid", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(0), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
})
|
||||
|
||||
t.Run("cleanup interval uses configured hours", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(3), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 12)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopPeriodicCleanup(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
mockStore.EXPECT().
|
||||
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||
Return(int64(1), nil).
|
||||
Times(1)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.StopPeriodicCleanup()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Expectations verified by gomock - would fail if more than 1 call happened
|
||||
}
|
||||
|
||||
func TestStopPeriodicCleanup_NotStarted(t *testing.T) {
|
||||
manager := &managerImpl{}
|
||||
|
||||
// Should not panic if cleanup was never started
|
||||
manager.StopPeriodicCleanup()
|
||||
}
|
||||
@@ -200,6 +200,11 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||
return Create(s, func() accesslogs.Manager {
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||
accessLogManager.StartPeriodicCleanup(
|
||||
context.Background(),
|
||||
s.Config.ReverseProxy.AccessLogRetentionDays,
|
||||
s.Config.ReverseProxy.AccessLogCleanupIntervalHours,
|
||||
)
|
||||
return accessLogManager
|
||||
})
|
||||
}
|
||||
|
||||
@@ -200,4 +200,13 @@ type ReverseProxy struct {
|
||||
// request headers if the peer's address falls within one of these
|
||||
// trusted IP prefixes.
|
||||
TrustedPeers []netip.Prefix
|
||||
|
||||
// AccessLogRetentionDays specifies the number of days to retain access logs.
|
||||
// Logs older than this duration will be automatically deleted during cleanup.
|
||||
// A value of 0 or negative means logs are kept indefinitely (no cleanup).
|
||||
AccessLogRetentionDays int
|
||||
|
||||
// AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine.
|
||||
// Defaults to 24 hours if not set or set to 0.
|
||||
AccessLogCleanupIntervalHours int
|
||||
}
|
||||
|
||||
@@ -157,6 +157,18 @@ type testSetup struct {
|
||||
// testAccessLogManager is a minimal mock for accesslogs.Manager.
|
||||
type testAccessLogManager struct{}
|
||||
|
||||
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) StopPeriodicCleanup() {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5083,8 +5083,20 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
|
||||
query = s.applyAccessLogFilters(query, filter)
|
||||
|
||||
sortColumns := filter.GetSortColumn()
|
||||
sortOrder := strings.ToUpper(filter.GetSortOrder())
|
||||
|
||||
var orderClauses []string
|
||||
for _, col := range strings.Split(sortColumns, ",") {
|
||||
col = strings.TrimSpace(col)
|
||||
if col != "" {
|
||||
orderClauses = append(orderClauses, col+" "+sortOrder)
|
||||
}
|
||||
}
|
||||
orderClause := strings.Join(orderClauses, ", ")
|
||||
|
||||
query = query.
|
||||
Order("timestamp DESC").
|
||||
Order(orderClause).
|
||||
Limit(filter.GetLimit()).
|
||||
Offset(filter.GetOffset())
|
||||
|
||||
@@ -5101,6 +5113,20 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
return logs, totalCount, nil
|
||||
}
|
||||
|
||||
// DeleteOldAccessLogs deletes all access logs older than the specified time
|
||||
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("timestamp < ?", olderThan).
|
||||
Delete(&accesslogs.AccessLogEntry{})
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete old access logs: %v", result.Error)
|
||||
return 0, status.Errorf(status.Internal, "failed to delete old access logs")
|
||||
}
|
||||
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// applyAccessLogFilters applies filter conditions to the query
|
||||
func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.AccessLogFilter) *gorm.DB {
|
||||
if filter.Search != nil {
|
||||
|
||||
@@ -270,6 +270,7 @@ type Store interface {
|
||||
|
||||
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||
|
||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||
|
||||
@@ -475,6 +475,21 @@ func (mr *MockStoreMockRecorder) DeleteNetworkRouter(ctx, accountID, routerID in
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteNetworkRouter", reflect.TypeOf((*MockStore)(nil).DeleteNetworkRouter), ctx, accountID, routerID)
|
||||
}
|
||||
|
||||
// DeleteOldAccessLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldAccessLogs", ctx, olderThan)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldAccessLogs indicates an expected call of DeleteOldAccessLogs.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldAccessLogs(ctx, olderThan interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAccessLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAccessLogs), ctx, olderThan)
|
||||
}
|
||||
|
||||
// DeletePAT mocks base method.
|
||||
func (m *MockStore) DeletePAT(ctx context.Context, userID, patID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -737,6 +737,14 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
if initiatorUserId != activity.SystemInitiator {
|
||||
freshInitiator, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, initiatorUserId)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err)
|
||||
}
|
||||
initiatorUser = freshInitiator
|
||||
}
|
||||
|
||||
oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, err
|
||||
@@ -864,7 +872,10 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
|
||||
return nil
|
||||
}
|
||||
|
||||
// @todo double check these
|
||||
if !initiatorUser.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
|
||||
}
|
||||
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
|
||||
@@ -2031,3 +2031,87 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
|
||||
t.Logf("Duplicate email error: %v", err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateUserUpdate_RejectsNonAdminInitiator(t *testing.T) {
|
||||
groupsMap := map[string]*types.Group{}
|
||||
|
||||
initiator := &types.User{
|
||||
Id: "initiator",
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
oldUser := &types.User{
|
||||
Id: "target",
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
update := &types.User{
|
||||
Id: "target",
|
||||
Role: types.UserRoleOwner,
|
||||
}
|
||||
|
||||
err := validateUserUpdate(groupsMap, initiator, oldUser, update)
|
||||
require.Error(t, err, "regular user should not be able to promote to owner")
|
||||
assert.Contains(t, err.Error(), "only admins and owners can update users")
|
||||
}
|
||||
|
||||
func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account1", "owner1", "", "", "", false)
|
||||
|
||||
adminID := "admin1"
|
||||
account.Users[adminID] = types.NewAdminUser(adminID)
|
||||
|
||||
targetID := "target1"
|
||||
account.Users[targetID] = types.NewRegularUser(targetID, "", "")
|
||||
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account))
|
||||
|
||||
demotedAdmin, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, adminID)
|
||||
require.NoError(t, err)
|
||||
demotedAdmin.Role = types.UserRoleUser
|
||||
require.NoError(t, s.SaveUser(context.Background(), demotedAdmin))
|
||||
|
||||
staleInitiator := &types.User{
|
||||
Id: adminID,
|
||||
AccountID: account.Id,
|
||||
Role: types.UserRoleAdmin,
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := DefaultAccountManager{
|
||||
Store: s,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
|
||||
settings, err := s.GetAccountSettings(context.Background(), store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := s.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
groupsMap := make(map[string]*types.Group, len(groups))
|
||||
for _, g := range groups {
|
||||
groupsMap[g.ID] = g
|
||||
}
|
||||
|
||||
update := &types.User{
|
||||
Id: targetID,
|
||||
Role: types.UserRoleAdmin,
|
||||
}
|
||||
|
||||
err = s.ExecuteInTransaction(context.Background(), func(tx store.Store) error {
|
||||
_, _, _, _, txErr := am.processUserUpdate(
|
||||
context.Background(), tx, groupsMap, account.Id, adminID, staleInitiator, update, false, settings,
|
||||
)
|
||||
return txErr
|
||||
})
|
||||
|
||||
require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted")
|
||||
assert.Contains(t, err.Error(), "only admins and owners can update users")
|
||||
|
||||
targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, types.UserRoleUser, targetUser.Role)
|
||||
}
|
||||
|
||||
@@ -166,6 +166,18 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
// testAccessLogManager provides access log storage for testing.
|
||||
type testAccessLogManager struct{}
|
||||
|
||||
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) StopPeriodicCleanup() {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7409,6 +7409,20 @@ paths:
|
||||
minimum: 1
|
||||
maximum: 100
|
||||
description: Number of items per page (max 100)
|
||||
- in: query
|
||||
name: sort_by
|
||||
schema:
|
||||
type: string
|
||||
enum: [timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason]
|
||||
default: timestamp
|
||||
description: Field to sort by (url sorts by host then path)
|
||||
- in: query
|
||||
name: sort_order
|
||||
schema:
|
||||
type: string
|
||||
enum: [asc, desc]
|
||||
default: desc
|
||||
description: Sort order (ascending or descending)
|
||||
- in: query
|
||||
name: search
|
||||
schema:
|
||||
|
||||
Reference in New Issue
Block a user