package statemanager import ( "context" "encoding/json" "errors" "fmt" "io/fs" "os" "reflect" "sync" "time" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/util" ) const ( errStateNotRegistered = "state %s not registered" errLoadStateFile = "load state file: %w" ) // State interface defines the methods that all state types must implement type State interface { Name() string } // CleanableState interface extends State with cleanup capability type CleanableState interface { State Cleanup() error } // RawState wraps raw JSON data for unregistered states type RawState struct { data json.RawMessage } func (r *RawState) Name() string { return "" // This is a placeholder implementation } // MarshalJSON implements json.Marshaler to preserve the original JSON func (r *RawState) MarshalJSON() ([]byte, error) { return r.data, nil } // Manager is the interface that exposes the persistence and state management methods. type Manager interface { Start() Stop(ctx context.Context) error RegisterState(state State) GetState(state State) State UpdateState(state State) error DeleteState(state State) error DeleteStateByName(stateName string) error DeleteAllStates() (int, error) PersistState(ctx context.Context) error LoadState(state State) error CleanupStateByName(name string) error PerformCleanup() error GetSavedStateNames() ([]string, error) } // Manager handles the persistence and management of various states type managerImpl struct { mu sync.Mutex cancel context.CancelFunc done chan struct{} filePath string // holds the states that are registered with the manager and that are to be persisted states map[string]State // holds the state names that have been updated and need to be persisted with the next save dirty map[string]struct{} // holds the type information for each registered state stateTypes map[string]reflect.Type } // New creates a new Manager instance func New(filePath string) Manager { return &managerImpl{ filePath: filePath, states: make(map[string]State), dirty: make(map[string]struct{}), stateTypes: make(map[string]reflect.Type), } } // Start starts the state manager periodic save routine func (m *managerImpl) Start() { if m == nil { return } m.mu.Lock() defer m.mu.Unlock() var ctx context.Context ctx, m.cancel = context.WithCancel(context.Background()) m.done = make(chan struct{}) go m.periodicStateSave(ctx) } func (m *managerImpl) Stop(ctx context.Context) error { if m == nil { return nil } m.mu.Lock() defer m.mu.Unlock() if m.cancel == nil { return nil } m.cancel() select { case <-ctx.Done(): return ctx.Err() case <-m.done: } return nil } // RegisterState registers a state with the manager but doesn't attempt to persist it. // Pass an uninitialized state to register it. func (m *managerImpl) RegisterState(state State) { if m == nil { return } m.mu.Lock() defer m.mu.Unlock() name := state.Name() m.states[name] = nil m.stateTypes[name] = reflect.TypeOf(state).Elem() } // GetState returns the state for the given type func (m *managerImpl) GetState(state State) State { if m == nil { return nil } m.mu.Lock() defer m.mu.Unlock() return m.states[state.Name()] } // UpdateState updates the state in the manager and marks it as dirty for the next save. // The state will be replaced with the new one. func (m *managerImpl) UpdateState(state State) error { if m == nil { return nil } return m.setState(state.Name(), state) } // DeleteState removes the state from the manager and marks it as dirty for the next save. // Pass an uninitialized state to delete it. func (m *managerImpl) DeleteState(state State) error { if m == nil { return nil } return m.setState(state.Name(), nil) } func (m *managerImpl) setState(name string, state State) error { m.mu.Lock() defer m.mu.Unlock() if _, exists := m.states[name]; !exists { return fmt.Errorf(errStateNotRegistered, name) } m.states[name] = state m.dirty[name] = struct{}{} return nil } // DeleteStateByName handles deletion of states without cleanup. // It doesn't require the state to be registered. func (m *managerImpl) DeleteStateByName(stateName string) error { if m == nil { return nil } m.mu.Lock() defer m.mu.Unlock() rawStates, err := m.loadStateFile(false) if err != nil { return fmt.Errorf(errLoadStateFile, err) } if rawStates == nil { return nil } if _, exists := rawStates[stateName]; !exists { return fmt.Errorf("state %s not found", stateName) } // Mark state as deleted by setting it to nil and marking it dirty m.states[stateName] = nil m.dirty[stateName] = struct{}{} return nil } // DeleteAllStates removes all states. func (m *managerImpl) DeleteAllStates() (int, error) { if m == nil { return 0, nil } m.mu.Lock() defer m.mu.Unlock() rawStates, err := m.loadStateFile(false) if err != nil { return 0, fmt.Errorf(errLoadStateFile, err) } if rawStates == nil { return 0, nil } count := len(rawStates) // Mark all states as deleted and dirty for name := range rawStates { m.states[name] = nil m.dirty[name] = struct{}{} } return count, nil } func (m *managerImpl) periodicStateSave(ctx context.Context) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() defer close(m.done) for { select { case <-ctx.Done(): return case <-ticker.C: if err := m.PersistState(ctx); err != nil { log.Errorf("failed to persist state: %v", err) } } } } // PersistState persists the states that have been updated since the last save. func (m *managerImpl) PersistState(ctx context.Context) error { if m == nil { return nil } m.mu.Lock() defer m.mu.Unlock() if len(m.dirty) == 0 { return nil } bs, err := marshalWithPanicRecovery(m.states) if err != nil { return fmt.Errorf("marshal states: %w", err) } ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) start := time.Now() go func() { done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs) }() select { case <-ctx.Done(): return ctx.Err() case err := <-done: if err != nil { return err } } log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) clear(m.dirty) return nil } // loadStateFile reads and unmarshals the state file into a map of raw JSON messages func (m *managerImpl) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) { data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { log.Debug("state file does not exist") return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) } var rawStates map[string]json.RawMessage if err := json.Unmarshal(data, &rawStates); err != nil { m.handleCorruptedState(deleteCorrupt) return nil, fmt.Errorf("unmarshal states: %w", err) } return rawStates, nil } // handleCorruptedState creates a backup of a corrupted state file by moving it func (m *managerImpl) handleCorruptedState(deleteCorrupt bool) { if !deleteCorrupt { return } log.Warn("State file appears to be corrupted, attempting to back it up") backupPath := fmt.Sprintf("%s.corrupted.%d", m.filePath, time.Now().UnixNano()) if err := os.Rename(m.filePath, backupPath); err != nil { log.Errorf("Failed to backup corrupted state file: %v", err) return } log.Infof("Created backup of corrupted state file at: %s", backupPath) } // loadSingleRawState unmarshals a raw state into a concrete state object func (m *managerImpl) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { stateType, ok := m.stateTypes[name] if !ok { return nil, fmt.Errorf(errStateNotRegistered, name) } if string(rawState) == "null" { return nil, nil //nolint:nilnil } statePtr := reflect.New(stateType).Interface().(State) if err := json.Unmarshal(rawState, statePtr); err != nil { return nil, fmt.Errorf("unmarshal state %s: %w", name, err) } return statePtr, nil } // LoadState loads a specific state from the state file func (m *managerImpl) LoadState(state State) error { if m == nil { return nil } m.mu.Lock() defer m.mu.Unlock() rawStates, err := m.loadStateFile(false) if err != nil { return err } if rawStates == nil { return nil } name := state.Name() rawState, exists := rawStates[name] if !exists { return nil } loadedState, err := m.loadSingleRawState(name, rawState) if err != nil { return err } m.states[name] = loadedState if loadedState != nil { log.Debugf("loaded state: %s", name) } return nil } // cleanupSingleState handles the cleanup of a specific state and returns any error. // The caller must hold the mutex. func (m *managerImpl) cleanupSingleState(name string, rawState json.RawMessage) error { // For unregistered states, preserve the raw JSON if _, registered := m.stateTypes[name]; !registered { m.states[name] = &RawState{data: rawState} return nil } // Load the state loadedState, err := m.loadSingleRawState(name, rawState) if err != nil { return err } if loadedState == nil { return nil } // Check if state supports cleanup cleanableState, isCleanable := loadedState.(CleanableState) if !isCleanable { // If it doesn't support cleanup, keep it as-is m.states[name] = loadedState return nil } // Perform cleanup log.Infof("cleaning up state %s", name) if err := cleanableState.Cleanup(); err != nil { // On cleanup error, preserve the state m.states[name] = loadedState return fmt.Errorf("cleanup state: %w", err) } // Successfully cleaned up - mark for deletion m.states[name] = nil m.dirty[name] = struct{}{} return nil } // CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState. // Returns an error if the state doesn't exist, isn't registered, or cleanup fails. func (m *managerImpl) CleanupStateByName(name string) error { if m == nil { return nil } m.mu.Lock() defer m.mu.Unlock() // Check if state is registered if _, registered := m.stateTypes[name]; !registered { return fmt.Errorf(errStateNotRegistered, name) } // Load raw states from file rawStates, err := m.loadStateFile(false) if err != nil { return err } if rawStates == nil { return nil } // Check if state exists in file rawState, exists := rawStates[name] if !exists { return nil } if err := m.cleanupSingleState(name, rawState); err != nil { return fmt.Errorf("%s: %w", name, err) } return nil } // PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it. // Unregistered states are preserved in their original state. func (m *managerImpl) PerformCleanup() error { if m == nil { return nil } m.mu.Lock() defer m.mu.Unlock() // Load raw states from file rawStates, err := m.loadStateFile(true) if err != nil { return fmt.Errorf(errLoadStateFile, err) } if rawStates == nil { return nil } var merr *multierror.Error // Process each state in the file for name, rawState := range rawStates { if err := m.cleanupSingleState(name, rawState); err != nil { merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err)) } } return nberrors.FormatErrorOrNil(merr) } // GetSavedStateNames returns all state names that are currently saved in the state file. func (m *managerImpl) GetSavedStateNames() ([]string, error) { if m == nil { return nil, nil } rawStates, err := m.loadStateFile(false) if err != nil { return nil, fmt.Errorf(errLoadStateFile, err) } if rawStates == nil { return nil, nil } var states []string for name, state := range rawStates { if len(state) != 0 && string(state) != "null" { states = append(states, name) } } return states, nil } func marshalWithPanicRecovery(v any) ([]byte, error) { var bs []byte var err error func() { defer func() { if r := recover(); r != nil { err = fmt.Errorf("panic during marshal: %v", r) } }() bs, err = json.Marshal(v) }() return bs, err }