[client] Persist route selection (#2810)

This commit is contained in:
Viktor Liu
2024-12-02 17:55:02 +01:00
committed by GitHub
parent ecb44ff306
commit 5142dc52c1
10 changed files with 273 additions and 40 deletions

View File

@@ -22,9 +22,28 @@ import (
// State interface defines the methods that all state types must implement
type State interface {
Name() string
}
// CleanableState interface extends State with cleanup capability
type CleanableState interface {
State
Cleanup() error
}
// RawState wraps raw JSON data for unregistered states
type RawState struct {
data json.RawMessage
}
func (r *RawState) Name() string {
return "" // This is a placeholder implementation
}
// MarshalJSON implements json.Marshaler to preserve the original JSON
func (r *RawState) MarshalJSON() ([]byte, error) {
return r.data, nil
}
// Manager handles the persistence and management of various states
type Manager struct {
mu sync.Mutex
@@ -209,15 +228,15 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}
// loadState loads the existing state from the state file
func (m *Manager) loadState() error {
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist")
return nil
return nil, nil // nolint:nilnil
}
return fmt.Errorf("read state file: %w", err)
return nil, fmt.Errorf("read state file: %w", err)
}
var rawStates map[string]json.RawMessage
@@ -228,37 +247,69 @@ func (m *Manager) loadState() error {
} else {
log.Info("State file deleted")
}
return fmt.Errorf("unmarshal states: %w", err)
return nil, fmt.Errorf("unmarshal states: %w", err)
}
var merr *multierror.Error
return rawStates, nil
}
for name, rawState := range rawStates {
stateType, ok := m.stateTypes[name]
if !ok {
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
continue
}
// loadSingleRawState unmarshals a raw state into a concrete state object
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
stateType, ok := m.stateTypes[name]
if !ok {
return nil, fmt.Errorf("state %s not registered", name)
}
if string(rawState) == "null" {
continue
}
if string(rawState) == "null" {
return nil, nil //nolint:nilnil
}
statePtr := reflect.New(stateType).Interface().(State)
if err := json.Unmarshal(rawState, statePtr); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
continue
}
statePtr := reflect.New(stateType).Interface().(State)
if err := json.Unmarshal(rawState, statePtr); err != nil {
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
}
m.states[name] = statePtr
return statePtr, nil
}
// LoadState loads a specific state from the state file
func (m *Manager) LoadState(state State) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile()
if err != nil {
return err
}
if rawStates == nil {
return nil
}
name := state.Name()
rawState, exists := rawStates[name]
if !exists {
return nil
}
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
m.states[name] = loadedState
if loadedState != nil {
log.Debugf("loaded state: %s", name)
}
return nberrors.FormatErrorOrNil(merr)
return nil
}
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
// If the cleanup is successful, the state is marked for deletion.
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
// Unregistered states are preserved in their original state.
func (m *Manager) PerformCleanup() error {
if m == nil {
return nil
@@ -267,22 +318,53 @@ func (m *Manager) PerformCleanup() error {
m.mu.Lock()
defer m.mu.Unlock()
if err := m.loadState(); err != nil {
// Load raw states from file
rawStates, err := m.loadStateFile()
if err != nil {
log.Warnf("Failed to load state during cleanup: %v", err)
return err
}
if rawStates == nil {
return nil
}
var merr *multierror.Error
for name, state := range m.states {
if state == nil {
// If no state was found in the state file, we don't mark the state dirty nor return an error
// Process each state in the file
for name, rawState := range rawStates {
// For unregistered states, preserve the raw JSON
if _, registered := m.stateTypes[name]; !registered {
m.states[name] = &RawState{data: rawState}
continue
}
// Load the registered state
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if loadedState == nil {
continue
}
// Check if state supports cleanup
cleanableState, isCleanable := loadedState.(CleanableState)
if !isCleanable {
// If it doesn't support cleanup, keep it as-is
m.states[name] = loadedState
continue
}
// Perform cleanup for cleanable states
log.Infof("client was not shut down properly, cleaning up %s", name)
if err := state.Cleanup(); err != nil {
if err := cleanableState.Cleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
// On cleanup error, preserve the state
m.states[name] = loadedState
} else {
// mark for deletion on cleanup success
// Successfully cleaned up - mark for deletion
m.states[name] = nil
m.dirty[name] = struct{}{}
}