mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client] Persist route selection (#2810)
This commit is contained in:
@@ -22,9 +22,28 @@ import (
|
||||
// State interface defines the methods that all state types must implement
|
||||
type State interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
// CleanableState interface extends State with cleanup capability
|
||||
type CleanableState interface {
|
||||
State
|
||||
Cleanup() error
|
||||
}
|
||||
|
||||
// RawState wraps raw JSON data for unregistered states
|
||||
type RawState struct {
|
||||
data json.RawMessage
|
||||
}
|
||||
|
||||
func (r *RawState) Name() string {
|
||||
return "" // This is a placeholder implementation
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler to preserve the original JSON
|
||||
func (r *RawState) MarshalJSON() ([]byte, error) {
|
||||
return r.data, nil
|
||||
}
|
||||
|
||||
// Manager handles the persistence and management of various states
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
@@ -209,15 +228,15 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadState loads the existing state from the state file
|
||||
func (m *Manager) loadState() error {
|
||||
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
|
||||
func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
|
||||
data, err := os.ReadFile(m.filePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
log.Debug("state file does not exist")
|
||||
return nil
|
||||
return nil, nil // nolint:nilnil
|
||||
}
|
||||
return fmt.Errorf("read state file: %w", err)
|
||||
return nil, fmt.Errorf("read state file: %w", err)
|
||||
}
|
||||
|
||||
var rawStates map[string]json.RawMessage
|
||||
@@ -228,37 +247,69 @@ func (m *Manager) loadState() error {
|
||||
} else {
|
||||
log.Info("State file deleted")
|
||||
}
|
||||
return fmt.Errorf("unmarshal states: %w", err)
|
||||
return nil, fmt.Errorf("unmarshal states: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
return rawStates, nil
|
||||
}
|
||||
|
||||
for name, rawState := range rawStates {
|
||||
stateType, ok := m.stateTypes[name]
|
||||
if !ok {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
|
||||
continue
|
||||
}
|
||||
// loadSingleRawState unmarshals a raw state into a concrete state object
|
||||
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
||||
stateType, ok := m.stateTypes[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("state %s not registered", name)
|
||||
}
|
||||
|
||||
if string(rawState) == "null" {
|
||||
continue
|
||||
}
|
||||
if string(rawState) == "null" {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
statePtr := reflect.New(stateType).Interface().(State)
|
||||
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
|
||||
continue
|
||||
}
|
||||
statePtr := reflect.New(stateType).Interface().(State)
|
||||
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
|
||||
}
|
||||
|
||||
m.states[name] = statePtr
|
||||
return statePtr, nil
|
||||
}
|
||||
|
||||
// LoadState loads a specific state from the state file
|
||||
func (m *Manager) LoadState(state State) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
rawStates, err := m.loadStateFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := state.Name()
|
||||
rawState, exists := rawStates[name]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.states[name] = loadedState
|
||||
if loadedState != nil {
|
||||
log.Debugf("loaded state: %s", name)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
|
||||
// If the cleanup is successful, the state is marked for deletion.
|
||||
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
|
||||
// Unregistered states are preserved in their original state.
|
||||
func (m *Manager) PerformCleanup() error {
|
||||
if m == nil {
|
||||
return nil
|
||||
@@ -267,22 +318,53 @@ func (m *Manager) PerformCleanup() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if err := m.loadState(); err != nil {
|
||||
// Load raw states from file
|
||||
rawStates, err := m.loadStateFile()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to load state during cleanup: %v", err)
|
||||
return err
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for name, state := range m.states {
|
||||
if state == nil {
|
||||
// If no state was found in the state file, we don't mark the state dirty nor return an error
|
||||
|
||||
// Process each state in the file
|
||||
for name, rawState := range rawStates {
|
||||
// For unregistered states, preserve the raw JSON
|
||||
if _, registered := m.stateTypes[name]; !registered {
|
||||
m.states[name] = &RawState{data: rawState}
|
||||
continue
|
||||
}
|
||||
|
||||
// Load the registered state
|
||||
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if loadedState == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if state supports cleanup
|
||||
cleanableState, isCleanable := loadedState.(CleanableState)
|
||||
if !isCleanable {
|
||||
// If it doesn't support cleanup, keep it as-is
|
||||
m.states[name] = loadedState
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform cleanup for cleanable states
|
||||
log.Infof("client was not shut down properly, cleaning up %s", name)
|
||||
if err := state.Cleanup(); err != nil {
|
||||
if err := cleanableState.Cleanup(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
||||
// On cleanup error, preserve the state
|
||||
m.states[name] = loadedState
|
||||
} else {
|
||||
// mark for deletion on cleanup success
|
||||
// Successfully cleaned up - mark for deletion
|
||||
m.states[name] = nil
|
||||
m.dirty[name] = struct{}{}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user