diff --git a/management/internals/modules/reverseproxy/accesslogs/filter.go b/management/internals/modules/reverseproxy/accesslogs/filter.go index f4b0a2048..a1fa28312 100644 --- a/management/internals/modules/reverseproxy/accesslogs/filter.go +++ b/management/internals/modules/reverseproxy/accesslogs/filter.go @@ -3,6 +3,7 @@ package accesslogs import ( "net/http" "strconv" + "strings" "time" ) @@ -11,15 +12,39 @@ const ( DefaultPageSize = 50 // MaxPageSize is the maximum number of records allowed per page MaxPageSize = 100 + + // Default sorting + DefaultSortBy = "timestamp" + DefaultSortOrder = "desc" ) -// AccessLogFilter holds pagination and filtering parameters for access logs +// Valid sortable fields mapped to their database column names or expressions +// For multi-column sorts, columns are separated by comma (e.g., "host, path") +var validSortFields = map[string]string{ + "timestamp": "timestamp", + "url": "host, path", // Sort by host first, then path + "host": "host", + "path": "path", + "method": "method", + "status_code": "status_code", + "duration": "duration", + "source_ip": "location_connection_ip", + "user_id": "user_id", + "auth_method": "auth_method_used", + "reason": "reason", +} + +// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs type AccessLogFilter struct { // Page is the current page number (1-indexed) Page int // PageSize is the number of records per page PageSize int + // Sorting parameters + SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason + SortOrder string // Sort order: asc or desc (default: desc) + // Filtering parameters Search *string // General search across log ID, host, path, source IP, and user fields SourceIP *string // Filter by source IP address @@ -35,13 +60,16 @@ type AccessLogFilter struct { EndDate *time.Time // Filter by timestamp <= end_date } -// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters +// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters func (f *AccessLogFilter) ParseFromRequest(r *http.Request) { queryParams := r.URL.Query() f.Page = parsePositiveInt(queryParams.Get("page"), 1) f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize) + f.SortBy = parseSortField(queryParams.Get("sort_by")) + f.SortOrder = parseSortOrder(queryParams.Get("sort_order")) + f.Search = parseOptionalString(queryParams.Get("search")) f.SourceIP = parseOptionalString(queryParams.Get("source_ip")) f.Host = parseOptionalString(queryParams.Get("host")) @@ -107,3 +135,44 @@ func (f *AccessLogFilter) GetOffset() int { func (f *AccessLogFilter) GetLimit() int { return f.PageSize } + +// GetSortColumn returns the validated database column name for sorting +func (f *AccessLogFilter) GetSortColumn() string { + if column, ok := validSortFields[f.SortBy]; ok { + return column + } + return validSortFields[DefaultSortBy] +} + +// GetSortOrder returns the validated sort order (ASC or DESC) +func (f *AccessLogFilter) GetSortOrder() string { + if f.SortOrder == "asc" || f.SortOrder == "desc" { + return f.SortOrder + } + return DefaultSortOrder +} + +// parseSortField validates and returns the sort field, defaulting if invalid +func parseSortField(s string) string { + if s == "" { + return DefaultSortBy + } + // Check if the field is valid + if _, ok := validSortFields[s]; ok { + return s + } + return DefaultSortBy +} + +// parseSortOrder validates and returns the sort order, defaulting if invalid +func parseSortOrder(s string) string { + if s == "" { + return DefaultSortOrder + } + // Normalize to lowercase + s = strings.ToLower(s) + if s == "asc" || s == "desc" { + return s + } + return DefaultSortOrder +} diff --git a/management/internals/modules/reverseproxy/accesslogs/filter_test.go b/management/internals/modules/reverseproxy/accesslogs/filter_test.go index 5d48ea9d2..ea1fce54b 100644 --- a/management/internals/modules/reverseproxy/accesslogs/filter_test.go +++ b/management/internals/modules/reverseproxy/accesslogs/filter_test.go @@ -361,6 +361,205 @@ func TestParseOptionalRFC3339(t *testing.T) { } } +func TestAccessLogFilter_SortingDefaults(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp") + assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc") + assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp") + assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc") +} + +func TestAccessLogFilter_ValidSortFields(t *testing.T) { + tests := []struct { + name string + sortBy string + expectedColumn string + expectedSortByVal string + }{ + {"timestamp", "timestamp", "timestamp", "timestamp"}, + {"url", "url", "host, path", "url"}, + {"host", "host", "host", "host"}, + {"path", "path", "path", "path"}, + {"method", "method", "method", "method"}, + {"status_code", "status_code", "status_code", "status_code"}, + {"duration", "duration", "duration", "duration"}, + {"source_ip", "source_ip", "location_connection_ip", "source_ip"}, + {"user_id", "user_id", "user_id", "user_id"}, + {"auth_method", "auth_method", "auth_method_used", "auth_method"}, + {"reason", "reason", "reason", "reason"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch") + assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch") + }) + } +} + +func TestAccessLogFilter_InvalidSortField(t *testing.T) { + tests := []struct { + name string + sortBy string + expected string + }{ + {"invalid field", "invalid_field", DefaultSortBy}, + {"empty field", "", DefaultSortBy}, + {"malicious input", "timestamp--DROP", DefaultSortBy}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + q := req.URL.Query() + q.Set("sort_by", tt.sortBy) + req.URL.RawQuery = q.Encode() + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp") + assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn()) + }) + } +} + +func TestAccessLogFilter_SortOrder(t *testing.T) { + tests := []struct { + name string + sortOrder string + expected string + }{ + {"ascending", "asc", "asc"}, + {"descending", "desc", "desc"}, + {"uppercase ASC", "ASC", "asc"}, + {"uppercase DESC", "DESC", "desc"}, + {"mixed case Asc", "Asc", "asc"}, + {"invalid order", "invalid", DefaultSortOrder}, + {"empty order", "", DefaultSortOrder}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch") + }) + } +} + +func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) { + tests := []struct { + name string + sortBy string + sortOrder string + expectedColumn string + expectedOrder string + }{ + { + name: "sort by host ascending", + sortBy: "host", + sortOrder: "asc", + expectedColumn: "host", + expectedOrder: "asc", + }, + { + name: "sort by duration descending", + sortBy: "duration", + sortOrder: "desc", + expectedColumn: "duration", + expectedOrder: "desc", + }, + { + name: "sort by status_code ascending", + sortBy: "status_code", + sortOrder: "asc", + expectedColumn: "status_code", + expectedOrder: "asc", + }, + { + name: "invalid sort with valid order", + sortBy: "invalid", + sortOrder: "asc", + expectedColumn: "timestamp", + expectedOrder: "asc", + }, + { + name: "valid sort with invalid order", + sortBy: "method", + sortOrder: "invalid", + expectedColumn: "method", + expectedOrder: DefaultSortOrder, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Equal(t, tt.expectedColumn, filter.GetSortColumn()) + assert.Equal(t, tt.expectedOrder, filter.GetSortOrder()) + }) + } +} + +func TestParseSortField(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"valid field", "host", "host"}, + {"empty string", "", DefaultSortBy}, + {"invalid field", "invalid", DefaultSortBy}, + {"malicious input", "timestamp--DROP", DefaultSortBy}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseSortField(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseSortOrder(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"asc lowercase", "asc", "asc"}, + {"desc lowercase", "desc", "desc"}, + {"ASC uppercase", "ASC", "asc"}, + {"DESC uppercase", "DESC", "desc"}, + {"invalid", "invalid", DefaultSortOrder}, + {"empty", "", DefaultSortOrder}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseSortOrder(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + // Helper functions for creating pointers func strPtr(s string) *string { return &s diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index e528cb4fb..018e54810 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5082,8 +5082,20 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin query = s.applyAccessLogFilters(query, filter) + sortColumns := filter.GetSortColumn() + sortOrder := strings.ToUpper(filter.GetSortOrder()) + + var orderClauses []string + for _, col := range strings.Split(sortColumns, ",") { + col = strings.TrimSpace(col) + if col != "" { + orderClauses = append(orderClauses, col+" "+sortOrder) + } + } + orderClause := strings.Join(orderClauses, ", ") + query = query. - Order("timestamp DESC"). + Order(orderClause). Limit(filter.GetLimit()). Offset(filter.GetOffset()) diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 1f4a163e5..b0ce1b5cc 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -7409,6 +7409,20 @@ paths: minimum: 1 maximum: 100 description: Number of items per page (max 100) + - in: query + name: sort_by + schema: + type: string + enum: [timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason] + default: timestamp + description: Field to sort by (url sorts by host then path) + - in: query + name: sort_order + schema: + type: string + enum: [asc, desc] + default: desc + description: Sort order (ascending or descending) - in: query name: search schema: