mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Cleanup firewall state on startup (#2768)
This commit is contained in:
@@ -14,6 +14,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,6 +26,13 @@ const (
|
||||
chainNameInput = "INPUT"
|
||||
)
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
@@ -35,30 +44,68 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
m.router, err = newRouter(context, workTable, wgIface)
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Init nftables firewall manager
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router.init(workTable); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Reset() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
// persist early
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
@@ -203,48 +250,80 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
if err := m.resetNetbirdInputRules(); err != nil {
|
||||
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
return fmt.Errorf("delete state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetNetbirdInputRules() error {
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
m.deleteNetbirdInputRules(chains)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.deleteMatchingRules(rules)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset forward rules: %v", err)
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
return fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.rConn.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
|
||||
Reference in New Issue
Block a user