diff --git a/management/internals/modules/reverseproxy/accesslogs/filter.go b/management/internals/modules/reverseproxy/accesslogs/filter.go index 17edc4b1c..f4b0a2048 100644 --- a/management/internals/modules/reverseproxy/accesslogs/filter.go +++ b/management/internals/modules/reverseproxy/accesslogs/filter.go @@ -39,78 +39,63 @@ type AccessLogFilter struct { func (f *AccessLogFilter) ParseFromRequest(r *http.Request) { queryParams := r.URL.Query() - f.Page = 1 - if pageStr := queryParams.Get("page"); pageStr != "" { - if page, err := strconv.Atoi(pageStr); err == nil && page > 0 { - f.Page = page - } - } + f.Page = parsePositiveInt(queryParams.Get("page"), 1) + f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize) - f.PageSize = DefaultPageSize - if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" { - if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 { - f.PageSize = pageSize - if f.PageSize > MaxPageSize { - f.PageSize = MaxPageSize - } - } - } + f.Search = parseOptionalString(queryParams.Get("search")) + f.SourceIP = parseOptionalString(queryParams.Get("source_ip")) + f.Host = parseOptionalString(queryParams.Get("host")) + f.Path = parseOptionalString(queryParams.Get("path")) + f.UserID = parseOptionalString(queryParams.Get("user_id")) + f.UserEmail = parseOptionalString(queryParams.Get("user_email")) + f.UserName = parseOptionalString(queryParams.Get("user_name")) + f.Method = parseOptionalString(queryParams.Get("method")) + f.Status = parseOptionalString(queryParams.Get("status")) + f.StatusCode = parseOptionalInt(queryParams.Get("status_code")) + f.StartDate = parseOptionalRFC3339(queryParams.Get("start_date")) + f.EndDate = parseOptionalRFC3339(queryParams.Get("end_date")) +} - if search := queryParams.Get("search"); search != "" { - f.Search = &search +// parsePositiveInt parses a positive integer from a string, returning defaultValue if invalid +func parsePositiveInt(s string, defaultValue int) int { + if s == "" { + return defaultValue } + if val, err := strconv.Atoi(s); err == nil && val > 0 { + return val + } + return defaultValue +} - if sourceIP := queryParams.Get("source_ip"); sourceIP != "" { - f.SourceIP = &sourceIP +// parseOptionalString returns a pointer to the string if non-empty, otherwise nil +func parseOptionalString(s string) *string { + if s == "" { + return nil } + return &s +} - if host := queryParams.Get("host"); host != "" { - f.Host = &host +// parseOptionalInt parses an optional positive integer from a string +func parseOptionalInt(s string) *int { + if s == "" { + return nil } + if val, err := strconv.Atoi(s); err == nil && val > 0 { + v := val + return &v + } + return nil +} - if path := queryParams.Get("path"); path != "" { - f.Path = &path +// parseOptionalRFC3339 parses an optional RFC3339 timestamp from a string +func parseOptionalRFC3339(s string) *time.Time { + if s == "" { + return nil } - - if userID := queryParams.Get("user_id"); userID != "" { - f.UserID = &userID - } - - if userEmail := queryParams.Get("user_email"); userEmail != "" { - f.UserEmail = &userEmail - } - - if userName := queryParams.Get("user_name"); userName != "" { - f.UserName = &userName - } - - if method := queryParams.Get("method"); method != "" { - f.Method = &method - } - - if status := queryParams.Get("status"); status != "" { - f.Status = &status - } - - if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" { - if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 { - f.StatusCode = &statusCode - } - } - - if startDate := queryParams.Get("start_date"); startDate != "" { - parsedStartDate, err := time.Parse(time.RFC3339, startDate) - if err == nil { - f.StartDate = &parsedStartDate - } - } - - if endDate := queryParams.Get("end_date"); endDate != "" { - parsedEndDate, err := time.Parse(time.RFC3339, endDate) - if err == nil { - f.EndDate = &parsedEndDate - } + if t, err := time.Parse(time.RFC3339, s); err == nil { + return &t } + return nil } // GetOffset calculates the database offset for pagination diff --git a/management/internals/modules/reverseproxy/accesslogs/filter_test.go b/management/internals/modules/reverseproxy/accesslogs/filter_test.go index 4ca4508bc..5d48ea9d2 100644 --- a/management/internals/modules/reverseproxy/accesslogs/filter_test.go +++ b/management/internals/modules/reverseproxy/accesslogs/filter_test.go @@ -4,8 +4,10 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAccessLogFilter_ParseFromRequest(t *testing.T) { @@ -159,3 +161,211 @@ func TestAccessLogFilter_GetLimit(t *testing.T) { limit := filter.GetLimit() assert.Equal(t, 25, limit, "GetLimit should return PageSize") } + +func TestAccessLogFilter_ParseFromRequest_FilterParams(t *testing.T) { + startDate := "2024-01-15T10:30:00Z" + endDate := "2024-01-16T15:45:00Z" + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + q := req.URL.Query() + q.Set("search", "test query") + q.Set("source_ip", "192.168.1.1") + q.Set("host", "example.com") + q.Set("path", "/api/users") + q.Set("user_id", "user123") + q.Set("user_email", "user@example.com") + q.Set("user_name", "John Doe") + q.Set("method", "GET") + q.Set("status", "success") + q.Set("status_code", "200") + q.Set("start_date", startDate) + q.Set("end_date", endDate) + req.URL.RawQuery = q.Encode() + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + require.NotNil(t, filter.Search) + assert.Equal(t, "test query", *filter.Search) + + require.NotNil(t, filter.SourceIP) + assert.Equal(t, "192.168.1.1", *filter.SourceIP) + + require.NotNil(t, filter.Host) + assert.Equal(t, "example.com", *filter.Host) + + require.NotNil(t, filter.Path) + assert.Equal(t, "/api/users", *filter.Path) + + require.NotNil(t, filter.UserID) + assert.Equal(t, "user123", *filter.UserID) + + require.NotNil(t, filter.UserEmail) + assert.Equal(t, "user@example.com", *filter.UserEmail) + + require.NotNil(t, filter.UserName) + assert.Equal(t, "John Doe", *filter.UserName) + + require.NotNil(t, filter.Method) + assert.Equal(t, "GET", *filter.Method) + + require.NotNil(t, filter.Status) + assert.Equal(t, "success", *filter.Status) + + require.NotNil(t, filter.StatusCode) + assert.Equal(t, 200, *filter.StatusCode) + + require.NotNil(t, filter.StartDate) + expectedStart, _ := time.Parse(time.RFC3339, startDate) + assert.Equal(t, expectedStart, *filter.StartDate) + + require.NotNil(t, filter.EndDate) + expectedEnd, _ := time.Parse(time.RFC3339, endDate) + assert.Equal(t, expectedEnd, *filter.EndDate) +} + +func TestAccessLogFilter_ParseFromRequest_EmptyFilters(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Nil(t, filter.Search) + assert.Nil(t, filter.SourceIP) + assert.Nil(t, filter.Host) + assert.Nil(t, filter.Path) + assert.Nil(t, filter.UserID) + assert.Nil(t, filter.UserEmail) + assert.Nil(t, filter.UserName) + assert.Nil(t, filter.Method) + assert.Nil(t, filter.Status) + assert.Nil(t, filter.StatusCode) + assert.Nil(t, filter.StartDate) + assert.Nil(t, filter.EndDate) +} + +func TestAccessLogFilter_ParseFromRequest_InvalidFilters(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + q := req.URL.Query() + q.Set("status_code", "invalid") + q.Set("start_date", "not-a-date") + q.Set("end_date", "2024-99-99") + req.URL.RawQuery = q.Encode() + + filter := &AccessLogFilter{} + filter.ParseFromRequest(req) + + assert.Nil(t, filter.StatusCode, "invalid status_code should be nil") + assert.Nil(t, filter.StartDate, "invalid start_date should be nil") + assert.Nil(t, filter.EndDate, "invalid end_date should be nil") +} + +func TestParsePositiveInt(t *testing.T) { + tests := []struct { + name string + input string + defaultValue int + expected int + }{ + {"empty string", "", 10, 10}, + {"valid positive int", "25", 10, 25}, + {"zero", "0", 10, 10}, + {"negative", "-5", 10, 10}, + {"invalid string", "abc", 10, 10}, + {"float", "3.14", 10, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parsePositiveInt(tt.input, tt.defaultValue) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseOptionalString(t *testing.T) { + tests := []struct { + name string + input string + expected *string + }{ + {"empty string", "", nil}, + {"valid string", "hello", strPtr("hello")}, + {"whitespace", " ", strPtr(" ")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseOptionalString(tt.input) + if tt.expected == nil { + assert.Nil(t, result) + } else { + require.NotNil(t, result) + assert.Equal(t, *tt.expected, *result) + } + }) + } +} + +func TestParseOptionalInt(t *testing.T) { + tests := []struct { + name string + input string + expected *int + }{ + {"empty string", "", nil}, + {"valid positive int", "42", intPtr(42)}, + {"zero", "0", nil}, + {"negative", "-10", nil}, + {"invalid string", "abc", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseOptionalInt(tt.input) + if tt.expected == nil { + assert.Nil(t, result) + } else { + require.NotNil(t, result) + assert.Equal(t, *tt.expected, *result) + } + }) + } +} + +func TestParseOptionalRFC3339(t *testing.T) { + validDate := "2024-01-15T10:30:00Z" + expectedTime, _ := time.Parse(time.RFC3339, validDate) + + tests := []struct { + name string + input string + expected *time.Time + }{ + {"empty string", "", nil}, + {"valid RFC3339", validDate, &expectedTime}, + {"invalid format", "2024-01-15", nil}, + {"invalid date", "not-a-date", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseOptionalRFC3339(tt.input) + if tt.expected == nil { + assert.Nil(t, result) + } else { + require.NotNil(t, result) + assert.Equal(t, *tt.expected, *result) + } + }) + } +} + +// Helper functions for creating pointers +func strPtr(s string) *string { + return &s +} + +func intPtr(i int) *int { + return &i +}