diff --git a/.env b/.env index c43c7a1..62b57b8 100644 --- a/.env +++ b/.env @@ -68,4 +68,10 @@ BASELINE_SUPPRESS_FOR=6h #BASELINE_MIN_COUNT=20 #BASELINE_MEDIUM_Z=3.0 #BASELINE_HIGH_Z=5.0 -#BASELINE_SUPPRESS_FOR=4h \ No newline at end of file +#BASELINE_SUPPRESS_FOR=4h + +PARTITION_MAINTENANCE_ENABLED=true +PARTITION_MAINTENANCE_INTERVAL=15m +PARTITION_INTERVAL=3h +PARTITION_AHEAD=24h +PARTITION_BEHIND=6h \ No newline at end of file diff --git a/main.go b/main.go index ed1be28..0f985c5 100644 --- a/main.go +++ b/main.go @@ -1163,6 +1163,13 @@ type Config struct { UEBALookback time.Duration UEBANewContextWindow time.Duration RiskScoreWindow time.Duration + + PartitionMaintenanceEnabled bool + PartitionMaintenanceInterval time.Duration + PartitionInterval time.Duration + PartitionAhead time.Duration + PartitionBehind time.Duration + PartitionRetention time.Duration } type LogPayload struct { @@ -1490,6 +1497,11 @@ type EventCountBucketAgg struct { LastTS time.Time } +type partitionedTable struct { + Name string + TimeColumn string +} + var ( httpRequestsTotal = prometheus.NewCounterVec( prometheus.CounterOpts{Name: "eventcollector_http_requests_total", Help: "Total HTTP requests."}, @@ -1699,10 +1711,11 @@ func main() { s.templates = tmpl go s.runSOCLoop() - go s.runDetectionLoop() go s.runBaselineLoop() + go s.runPartitionMaintenanceLoop() + mux := http.NewServeMux() mux.HandleFunc("/healthz", s.handleHealthz) mux.HandleFunc("/readyz", s.handleReadyz) @@ -1983,6 +1996,279 @@ func (s *server) handleUIPrivilegedUsers(w http.ResponseWriter, r *http.Request) }) } +func (s *server) runPartitionMaintenanceLoop() { + if !s.cfg.PartitionMaintenanceEnabled { + s.logger.Printf("partition maintenance disabled") + return + } + + s.runPartitionMaintenanceOnce() + + ticker := time.NewTicker(s.cfg.PartitionMaintenanceInterval) + defer ticker.Stop() + + for range ticker.C { + s.runPartitionMaintenanceOnce() + } +} + +func (s *server) runPartitionMaintenanceOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + start := time.Now() + + if err := s.ensureConfiguredPartitions(ctx); err != nil { + s.logger.Printf("partition maintenance error after %s: %v", time.Since(start), err) + return + } + + s.logger.Printf("partition maintenance completed in %s", time.Since(start)) +} + +func (s *server) ensureTableIsPartitioned(ctx context.Context, tableName string) error { + var partitionCount int + + err := s.db.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM information_schema.PARTITIONS +WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = ? + AND PARTITION_NAME IS NOT NULL +`, tableName).Scan(&partitionCount) + + if err != nil { + return fmt.Errorf("check partitioned table %s: %w", tableName, err) + } + + if partitionCount == 0 { + return fmt.Errorf( + "table %s is not partitioned; run the SQL migration first", + tableName, + ) + } + + return nil +} + +func (s *server) ensure3HourPartitions(ctx context.Context, tbl partitionedTable) error { + interval := s.cfg.PartitionInterval + if interval <= 0 { + interval = 3 * time.Hour + } + + ahead := s.cfg.PartitionAhead + if ahead <= 0 { + ahead = 24 * time.Hour + } + + behind := s.cfg.PartitionBehind + if behind < 0 { + behind = 0 + } + + now := time.Now().UTC() + + start := partitionFloor(now.Add(-behind), interval) + end := partitionFloor(now.Add(ahead), interval).Add(interval) + + for pStart := start; pStart.Before(end); pStart = pStart.Add(interval) { + pEnd := pStart.Add(interval) + + exists, err := s.partitionExists(ctx, tbl.Name, partitionName(pStart)) + if err != nil { + return err + } + if exists { + continue + } + + if err := s.addPartitionBeforePMax(ctx, tbl.Name, pStart, pEnd); err != nil { + return err + } + } + + return nil +} + +func partitionFloor(t time.Time, interval time.Duration) time.Time { + t = t.UTC() + + if interval <= 0 { + interval = 3 * time.Hour + } + + seconds := int64(interval.Seconds()) + unix := t.Unix() + floored := unix - (unix % seconds) + + return time.Unix(floored, 0).UTC() +} + +func partitionName(start time.Time) string { + return "p" + start.UTC().Format("2006010215") +} + +func mysqlDateTimeLiteral(t time.Time) string { + return t.UTC().Format("2006-01-02 15:04:05") +} + +func (s *server) partitionExists(ctx context.Context, tableName, partitionName string) (bool, error) { + var count int + + err := s.db.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM information_schema.PARTITIONS +WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = ? + AND PARTITION_NAME = ? +`, tableName, partitionName).Scan(&count) + + if err != nil { + return false, fmt.Errorf("check partition exists %s.%s: %w", tableName, partitionName, err) + } + + return count > 0, nil +} + +func (s *server) addPartitionBeforePMax(ctx context.Context, tableName string, start, end time.Time) error { + pName := partitionName(start) + endLit := mysqlDateTimeLiteral(end) + + if !safeIdentifier(tableName) || !safeIdentifier(pName) { + return fmt.Errorf("unsafe partition/table identifier: table=%q partition=%q", tableName, pName) + } + + query := fmt.Sprintf(` +ALTER TABLE %s +REORGANIZE PARTITION pmax INTO ( + PARTITION %s VALUES LESS THAN ('%s'), + PARTITION pmax VALUES LESS THAN (MAXVALUE) +) +`, tableName, pName, endLit) + + s.logger.Printf("creating partition table=%s partition=%s less_than=%s", tableName, pName, endLit) + + if _, err := s.db.ExecContext(ctx, query); err != nil { + return fmt.Errorf("create partition %s on %s: %w", pName, tableName, err) + } + + return nil +} + +func safeIdentifier(v string) bool { + if v == "" { + return false + } + + for _, r := range v { + if r >= 'a' && r <= 'z' { + continue + } + if r >= 'A' && r <= 'Z' { + continue + } + if r >= '0' && r <= '9' { + continue + } + if r == '_' { + continue + } + return false + } + + return true +} + +func (s *server) dropOldPartitions(ctx context.Context, tableName string, retention time.Duration) error { + if retention <= 0 { + return nil + } + + cutoff := partitionFloor(time.Now().UTC().Add(-retention), s.cfg.PartitionInterval) + + rows, err := s.db.QueryContext(ctx, ` +SELECT PARTITION_NAME +FROM information_schema.PARTITIONS +WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = ? + AND PARTITION_NAME IS NOT NULL + AND PARTITION_NAME <> 'pmax' +`, tableName) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var pName string + if err := rows.Scan(&pName); err != nil { + return err + } + + pStart, ok := parsePartitionName(pName) + if !ok { + continue + } + + if !pStart.Before(cutoff) { + continue + } + + if !safeIdentifier(tableName) || !safeIdentifier(pName) { + return fmt.Errorf("unsafe identifier while dropping partition: %s.%s", tableName, pName) + } + + query := fmt.Sprintf(`ALTER TABLE %s DROP PARTITION %s`, tableName, pName) + + s.logger.Printf("dropping old partition table=%s partition=%s", tableName, pName) + + if _, err := s.db.ExecContext(ctx, query); err != nil { + return fmt.Errorf("drop partition %s on %s: %w", pName, tableName, err) + } + } + + return rows.Err() +} + +func parsePartitionName(name string) (time.Time, bool) { + if len(name) != len("p2006010215") || !strings.HasPrefix(name, "p") { + return time.Time{}, false + } + + t, err := time.ParseInLocation("2006010215", strings.TrimPrefix(name, "p"), time.UTC) + if err != nil { + return time.Time{}, false + } + + return t.UTC(), true +} + +func (s *server) ensureConfiguredPartitions(ctx context.Context) error { + tables := []partitionedTable{ + {Name: "event_logs", TimeColumn: "ts"}, + {Name: "event_logs_raw", TimeColumn: "ts"}, + } + + for _, tbl := range tables { + if err := s.ensureTableIsPartitioned(ctx, tbl.Name); err != nil { + return err + } + + if err := s.ensure3HourPartitions(ctx, tbl); err != nil { + return err + } + + if s.cfg.PartitionRetention > 0 { + if err := s.dropOldPartitions(ctx, tbl.Name, s.cfg.PartitionRetention); err != nil { + return err + } + } + } + + return nil +} + func (s *server) handleUIPrivilegedUserSave(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeError(w, http.StatusMethodNotAllowed, "method not allowed") @@ -2127,7 +2413,7 @@ WHERE id IN ( SELECT id FROM ( SELECT id FROM detections - WHERE 1=1 + WHERE 1=10 ` args := []any{ status, @@ -3538,6 +3824,13 @@ func loadConfig() Config { UEBALookback: getenvDuration("UEBA_LOOKBACK", 30*24*time.Hour), UEBANewContextWindow: getenvDuration("UEBA_NEW_CONTEXT_WINDOW", 10*time.Minute), RiskScoreWindow: getenvDuration("RISK_SCORE_WINDOW", 24*time.Hour), + + PartitionMaintenanceEnabled: getenvBool("PARTITION_MAINTENANCE_ENABLED", true), + PartitionMaintenanceInterval: getenvDuration("PARTITION_MAINTENANCE_INTERVAL", 15*time.Minute), + PartitionInterval: getenvDuration("PARTITION_INTERVAL", 3*time.Hour), + PartitionAhead: getenvDuration("PARTITION_AHEAD", 24*time.Hour), + PartitionBehind: getenvDuration("PARTITION_BEHIND", 6*time.Hour), + PartitionRetention: getenvDuration("PARTITION_RETENTION", 30*24*time.Hour), } }