mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 09:46:40 +00:00
refactor access log filter
This commit is contained in:
@@ -39,78 +39,63 @@ type AccessLogFilter struct {
|
|||||||
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||||
queryParams := r.URL.Query()
|
queryParams := r.URL.Query()
|
||||||
|
|
||||||
f.Page = 1
|
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||||
if pageStr := queryParams.Get("page"); pageStr != "" {
|
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||||
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
|
||||||
f.Page = page
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
f.PageSize = DefaultPageSize
|
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||||
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
|
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||||
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||||
f.PageSize = pageSize
|
f.Path = parseOptionalString(queryParams.Get("path"))
|
||||||
if f.PageSize > MaxPageSize {
|
f.UserID = parseOptionalString(queryParams.Get("user_id"))
|
||||||
f.PageSize = MaxPageSize
|
f.UserEmail = parseOptionalString(queryParams.Get("user_email"))
|
||||||
}
|
f.UserName = parseOptionalString(queryParams.Get("user_name"))
|
||||||
}
|
f.Method = parseOptionalString(queryParams.Get("method"))
|
||||||
}
|
f.Status = parseOptionalString(queryParams.Get("status"))
|
||||||
|
f.StatusCode = parseOptionalInt(queryParams.Get("status_code"))
|
||||||
|
f.StartDate = parseOptionalRFC3339(queryParams.Get("start_date"))
|
||||||
|
f.EndDate = parseOptionalRFC3339(queryParams.Get("end_date"))
|
||||||
|
}
|
||||||
|
|
||||||
if search := queryParams.Get("search"); search != "" {
|
// parsePositiveInt parses a positive integer from a string, returning defaultValue if invalid
|
||||||
f.Search = &search
|
func parsePositiveInt(s string, defaultValue int) int {
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
|
// parseOptionalString returns a pointer to the string if non-empty, otherwise nil
|
||||||
f.SourceIP = &sourceIP
|
func parseOptionalString(s string) *string {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
if host := queryParams.Get("host"); host != "" {
|
// parseOptionalInt parses an optional positive integer from a string
|
||||||
f.Host = &host
|
func parseOptionalInt(s string) *int {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||||
|
v := val
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if path := queryParams.Get("path"); path != "" {
|
// parseOptionalRFC3339 parses an optional RFC3339 timestamp from a string
|
||||||
f.Path = &path
|
func parseOptionalRFC3339(s string) *time.Time {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||||
if userID := queryParams.Get("user_id"); userID != "" {
|
return &t
|
||||||
f.UserID = &userID
|
|
||||||
}
|
|
||||||
|
|
||||||
if userEmail := queryParams.Get("user_email"); userEmail != "" {
|
|
||||||
f.UserEmail = &userEmail
|
|
||||||
}
|
|
||||||
|
|
||||||
if userName := queryParams.Get("user_name"); userName != "" {
|
|
||||||
f.UserName = &userName
|
|
||||||
}
|
|
||||||
|
|
||||||
if method := queryParams.Get("method"); method != "" {
|
|
||||||
f.Method = &method
|
|
||||||
}
|
|
||||||
|
|
||||||
if status := queryParams.Get("status"); status != "" {
|
|
||||||
f.Status = &status
|
|
||||||
}
|
|
||||||
|
|
||||||
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
|
|
||||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
|
|
||||||
f.StatusCode = &statusCode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if startDate := queryParams.Get("start_date"); startDate != "" {
|
|
||||||
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
|
|
||||||
if err == nil {
|
|
||||||
f.StartDate = &parsedStartDate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if endDate := queryParams.Get("end_date"); endDate != "" {
|
|
||||||
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
|
|
||||||
if err == nil {
|
|
||||||
f.EndDate = &parsedEndDate
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOffset calculates the database offset for pagination
|
// GetOffset calculates the database offset for pagination
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
||||||
@@ -159,3 +161,211 @@ func TestAccessLogFilter_GetLimit(t *testing.T) {
|
|||||||
limit := filter.GetLimit()
|
limit := filter.GetLimit()
|
||||||
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_FilterParams(t *testing.T) {
|
||||||
|
startDate := "2024-01-15T10:30:00Z"
|
||||||
|
endDate := "2024-01-16T15:45:00Z"
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set("search", "test query")
|
||||||
|
q.Set("source_ip", "192.168.1.1")
|
||||||
|
q.Set("host", "example.com")
|
||||||
|
q.Set("path", "/api/users")
|
||||||
|
q.Set("user_id", "user123")
|
||||||
|
q.Set("user_email", "user@example.com")
|
||||||
|
q.Set("user_name", "John Doe")
|
||||||
|
q.Set("method", "GET")
|
||||||
|
q.Set("status", "success")
|
||||||
|
q.Set("status_code", "200")
|
||||||
|
q.Set("start_date", startDate)
|
||||||
|
q.Set("end_date", endDate)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Search)
|
||||||
|
assert.Equal(t, "test query", *filter.Search)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.SourceIP)
|
||||||
|
assert.Equal(t, "192.168.1.1", *filter.SourceIP)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Host)
|
||||||
|
assert.Equal(t, "example.com", *filter.Host)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Path)
|
||||||
|
assert.Equal(t, "/api/users", *filter.Path)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserID)
|
||||||
|
assert.Equal(t, "user123", *filter.UserID)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserEmail)
|
||||||
|
assert.Equal(t, "user@example.com", *filter.UserEmail)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserName)
|
||||||
|
assert.Equal(t, "John Doe", *filter.UserName)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Method)
|
||||||
|
assert.Equal(t, "GET", *filter.Method)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Status)
|
||||||
|
assert.Equal(t, "success", *filter.Status)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.StatusCode)
|
||||||
|
assert.Equal(t, 200, *filter.StatusCode)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.StartDate)
|
||||||
|
expectedStart, _ := time.Parse(time.RFC3339, startDate)
|
||||||
|
assert.Equal(t, expectedStart, *filter.StartDate)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.EndDate)
|
||||||
|
expectedEnd, _ := time.Parse(time.RFC3339, endDate)
|
||||||
|
assert.Equal(t, expectedEnd, *filter.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_EmptyFilters(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Nil(t, filter.Search)
|
||||||
|
assert.Nil(t, filter.SourceIP)
|
||||||
|
assert.Nil(t, filter.Host)
|
||||||
|
assert.Nil(t, filter.Path)
|
||||||
|
assert.Nil(t, filter.UserID)
|
||||||
|
assert.Nil(t, filter.UserEmail)
|
||||||
|
assert.Nil(t, filter.UserName)
|
||||||
|
assert.Nil(t, filter.Method)
|
||||||
|
assert.Nil(t, filter.Status)
|
||||||
|
assert.Nil(t, filter.StatusCode)
|
||||||
|
assert.Nil(t, filter.StartDate)
|
||||||
|
assert.Nil(t, filter.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_InvalidFilters(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set("status_code", "invalid")
|
||||||
|
q.Set("start_date", "not-a-date")
|
||||||
|
q.Set("end_date", "2024-99-99")
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Nil(t, filter.StatusCode, "invalid status_code should be nil")
|
||||||
|
assert.Nil(t, filter.StartDate, "invalid start_date should be nil")
|
||||||
|
assert.Nil(t, filter.EndDate, "invalid end_date should be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePositiveInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
defaultValue int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"empty string", "", 10, 10},
|
||||||
|
{"valid positive int", "25", 10, 25},
|
||||||
|
{"zero", "0", 10, 10},
|
||||||
|
{"negative", "-5", 10, 10},
|
||||||
|
{"invalid string", "abc", 10, 10},
|
||||||
|
{"float", "3.14", 10, 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parsePositiveInt(tt.input, tt.defaultValue)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *string
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid string", "hello", strPtr("hello")},
|
||||||
|
{"whitespace", " ", strPtr(" ")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalString(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *int
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid positive int", "42", intPtr(42)},
|
||||||
|
{"zero", "0", nil},
|
||||||
|
{"negative", "-10", nil},
|
||||||
|
{"invalid string", "abc", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalInt(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalRFC3339(t *testing.T) {
|
||||||
|
validDate := "2024-01-15T10:30:00Z"
|
||||||
|
expectedTime, _ := time.Parse(time.RFC3339, validDate)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *time.Time
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid RFC3339", validDate, &expectedTime},
|
||||||
|
{"invalid format", "2024-01-15", nil},
|
||||||
|
{"invalid date", "not-a-date", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalRFC3339(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for creating pointers
|
||||||
|
func strPtr(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func intPtr(i int) *int {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user