mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
[client] Cleanup firewall state on startup (#2768)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
_ = fw.Reset(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
@@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
_ = fw.Reset(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
|
||||
@@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error {
|
||||
}
|
||||
|
||||
// RunWithProbes runs the client's main logic with probes attached
|
||||
func (c *ConnectClient) RunWithProbes(
|
||||
probes *ProbeHolder,
|
||||
runningChan chan error,
|
||||
) error {
|
||||
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, probes, runningChan)
|
||||
}
|
||||
|
||||
@@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(
|
||||
mobileDependency MobileDependency,
|
||||
probes *ProbeHolder,
|
||||
runningChan chan error,
|
||||
) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
|
||||
@@ -533,6 +533,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||
s.addHostRootZone()
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -366,7 +367,7 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
}
|
||||
@@ -1167,7 +1168,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Reset()
|
||||
err := e.firewall.Reset(e.stateManager)
|
||||
if err != nil {
|
||||
log.Warnf("failed to reset firewall: %s", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package refcounter
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
@@ -70,6 +71,19 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
||||
}
|
||||
}
|
||||
|
||||
// LoadData loads the data from the existing counter
|
||||
func (rm *Counter[Key, I, O]) LoadData(
|
||||
existingCounter *Counter[Key, I, O],
|
||||
) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
rm.refCountMap = existingCounter.refCountMap
|
||||
rm.idMap = existingCounter.idMap
|
||||
}
|
||||
|
||||
// Get retrieves the current reference count and associated data for a key.
|
||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
@@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
clear(rm.idMap)
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
}{
|
||||
RefCountMap: rm.refCountMap,
|
||||
IDMap: rm.idMap,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
||||
var temp struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
rm.refCountMap = temp.RefCountMap
|
||||
rm.idMap = temp.IDMap
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCallerInfo(depth int, maxDepth int) (string, bool) {
|
||||
if depth >= maxDepth {
|
||||
return "", false
|
||||
|
||||
@@ -1,30 +1,15 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
type RouteEntry struct {
|
||||
Prefix netip.Prefix `json:"prefix"`
|
||||
Nexthop Nexthop `json:"nexthop"`
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewShutdownState() *ShutdownState {
|
||||
return &ShutdownState{
|
||||
Routes: make(map[netip.Prefix]RouteEntry),
|
||||
}
|
||||
Counter *ExclusionCounter `json:"counter,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -32,50 +17,16 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.Counter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sysops := NewSysOps(nil, nil)
|
||||
var merr *multierror.Error
|
||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||
sysops.refCounter.LoadData(s.Counter)
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, route := range s.Routes {
|
||||
if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.Routes[prefix] = RouteEntry{
|
||||
Prefix: prefix,
|
||||
Nexthop: nexthop,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.Routes, prefix)
|
||||
}
|
||||
|
||||
// MarshalJSON ensures that empty routes are marshaled as null
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if len(s.Routes) == 0 {
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
|
||||
return json.Marshal(s.Routes)
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||
return json.Unmarshal(data, &s.Routes)
|
||||
return sysops.refCounter.Flush()
|
||||
}
|
||||
|
||||
@@ -57,14 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
|
||||
r.updateState(stateManager, prefix, nexthop)
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nexthop, err
|
||||
},
|
||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
// remove from state even if we have trouble removing it from the route table
|
||||
// it could be already gone
|
||||
r.removeFromState(stateManager, prefix)
|
||||
r.updateState(stateManager)
|
||||
|
||||
return r.removeFromRouteTable(prefix, nexthop)
|
||||
},
|
||||
@@ -75,24 +75,16 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
return r.setupHooks(initAddresses)
|
||||
}
|
||||
|
||||
func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) {
|
||||
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
||||
state := getState(stateManager)
|
||||
state.UpdateRoute(prefix, nexthop)
|
||||
|
||||
state.Counter = r.refCounter
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) {
|
||||
state := getState(stateManager)
|
||||
state.RemoveRoute(prefix)
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("Failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
if r.refCounter == nil {
|
||||
return nil
|
||||
@@ -107,7 +99,7 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
log.Errorf("failed to delete state: %v", err)
|
||||
return fmt.Errorf("delete state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -546,7 +538,7 @@ func getState(stateManager *statemanager.Manager) *ShutdownState {
|
||||
if state := stateManager.GetState(shutdownState); state != nil {
|
||||
shutdownState = state.(*ShutdownState)
|
||||
} else {
|
||||
shutdownState = NewShutdownState()
|
||||
shutdownState = &ShutdownState{}
|
||||
}
|
||||
|
||||
return shutdownState
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||
@@ -27,7 +27,7 @@ func GetDefaultStatePath() string {
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user