mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 03:36:41 +00:00
[client] Enable userspace forwarder conditionally (#3309)
* Enable userspace forwarder conditionally * Move disable/enable logic
This commit is contained in:
@@ -74,6 +74,8 @@ type Manager struct {
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
@@ -125,15 +127,27 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func parseCreateEnv() (bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||
disableConntrack, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
|
||||
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||
}
|
||||
enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding))
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@@ -149,15 +163,16 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
return d
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
netstack: netstack.IsEnabled(),
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
netstack: netstack.IsEnabled(),
|
||||
// default true for non-netstack, for netstack only if explicitly enabled
|
||||
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
|
||||
}
|
||||
@@ -166,7 +181,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
}
|
||||
|
||||
// Only initialize trackers if stateful mode is enabled
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
@@ -175,7 +189,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
m.determineRouting(iface, disableServerRoutes)
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding {
|
||||
if err := m.initForwarder(); err != nil {
|
||||
log.Errorf("failed to initialize forwarder: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.blockInvalidRouted(iface); err != nil {
|
||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||
@@ -213,9 +232,21 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) {
|
||||
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
|
||||
forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter))
|
||||
func (m *Manager) determineRouting() error {
|
||||
var disableUspRouting, forceUserspaceRouter bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||
disableUspRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
@@ -223,7 +254,7 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes
|
||||
m.nativeRouter = false
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case disableServerRoutes:
|
||||
case m.disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
@@ -252,32 +283,37 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding ||
|
||||
m.routingEnabled && !m.nativeRouter {
|
||||
|
||||
m.initForwarder(iface)
|
||||
if m.routingEnabled && !m.nativeRouter {
|
||||
return m.initForwarder()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder(iface common.IFaceMapper) {
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := iface.GetWGDevice()
|
||||
if intf == nil {
|
||||
log.Info("forwarding not supported")
|
||||
m.routingEnabled = false
|
||||
return
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(iface, m.logger, m.netstack)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create forwarder: %v", err)
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := m.wgIface.GetWGDevice()
|
||||
if intf == nil {
|
||||
m.routingEnabled = false
|
||||
return
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||
if err != nil {
|
||||
m.routingEnabled = false
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
}
|
||||
|
||||
m.forwarder = forwarder
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
@@ -285,7 +321,7 @@ func (m *Manager) Init(*statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
@@ -586,7 +622,6 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if !m.isValidPacket(d, packetData) {
|
||||
m.logger.Trace("Invalid packet structure")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -658,11 +693,9 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
||||
return false
|
||||
}
|
||||
|
||||
// Get protocol and ports for route ACL check
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
// Check route ACLs
|
||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||
srcIP, srcPort, dstIP, dstPort, proto)
|
||||
@@ -704,12 +737,12 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
log.Tracef("couldn't decode layer, err: %s", err)
|
||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
log.Tracef("not enough levels in network packet")
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -953,3 +986,34 @@ func (m *Manager) SetLogLevel(level log.Level) {
|
||||
m.logger.SetLevel(nblog.Level(level))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.determineRouting()
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.forwarder.Stop()
|
||||
m.forwarder = nil
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -303,6 +303,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NoError(tb, err)
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled)
|
||||
|
||||
Reference in New Issue
Block a user