[client] Cleanup firewall state on startup (#2768)

This commit is contained in:
Viktor Liu
2024-10-24 14:46:24 +02:00
committed by GitHub
parent 4e918e55ba
commit 8016710d24
32 changed files with 739 additions and 318 deletions

View File

@@ -1,6 +1,7 @@
package refcounter
import (
"encoding/json"
"errors"
"fmt"
"runtime"
@@ -70,6 +71,19 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
}
}
// LoadData loads the data from the existing counter
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
}
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
@@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() {
clear(rm.idMap)
}
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
}{
RefCountMap: rm.refCountMap,
IDMap: rm.idMap,
})
}
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
var temp struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
rm.refCountMap = temp.RefCountMap
rm.idMap = temp.IDMap
return nil
}
func getCallerInfo(depth int, maxDepth int) (string, bool) {
if depth >= maxDepth {
return "", false

View File

@@ -1,30 +1,15 @@
package systemops
import (
"encoding/json"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
type RouteEntry struct {
Prefix netip.Prefix `json:"prefix"`
Nexthop Nexthop `json:"nexthop"`
}
type ShutdownState struct {
Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"`
mu sync.RWMutex
}
func NewShutdownState() *ShutdownState {
return &ShutdownState{
Routes: make(map[netip.Prefix]RouteEntry),
}
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
func (s *ShutdownState) Name() string {
@@ -32,50 +17,16 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil)
var merr *multierror.Error
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter)
s.mu.RLock()
defer s.mu.RUnlock()
for _, route := range s.Routes {
if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) {
s.mu.Lock()
defer s.mu.Unlock()
s.Routes[prefix] = RouteEntry{
Prefix: prefix,
Nexthop: nexthop,
}
}
func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.Routes, prefix)
}
// MarshalJSON ensures that empty routes are marshaled as null
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.Routes) == 0 {
return json.Marshal(nil)
}
return json.Marshal(s.Routes)
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &s.Routes)
return sysops.refCounter.Flush()
}

View File

@@ -57,14 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore
}
r.updateState(stateManager, prefix, nexthop)
r.updateState(stateManager)
return nexthop, err
},
func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table
// it could be already gone
r.removeFromState(stateManager, prefix)
r.updateState(stateManager)
return r.removeFromRouteTable(prefix, nexthop)
},
@@ -75,24 +75,16 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return r.setupHooks(initAddresses)
}
func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) {
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager)
state.UpdateRoute(prefix, nexthop)
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) {
state := getState(stateManager)
state.RemoveRoute(prefix)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("Failed to update state: %v", err)
}
}
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
if r.refCounter == nil {
return nil
@@ -107,7 +99,7 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
log.Errorf("failed to delete state: %v", err)
return fmt.Errorf("delete state: %w", err)
}
return nil
@@ -546,7 +538,7 @@ func getState(stateManager *statemanager.Manager) *ShutdownState {
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = NewShutdownState()
shutdownState = &ShutdownState{}
}
return shutdownState