Compare commits

...

34 Commits

Author SHA1 Message Date
Viktor Liu
ff5eddf70b Merge branch 'main' into add-ns-punnycode-support 2025-06-08 13:14:52 +02:00
Maycon Santos
0f050e5fe1 [client] Optmize process check time (#3938)
This PR optimizes the process check time by updating the implementation of getRunningProcesses and introducing new benchmark tests.

Updated getRunningProcesses to use process.Pids() instead of process.Processes()
Added benchmark tests for both the new and the legacy implementations

Benchmark: https://github.com/netbirdio/netbird/actions/runs/15512741612

todo: evaluate windows optmizations and caching risks
2025-06-08 12:19:54 +02:00
Maycon Santos
0f7c7f1da2 [misc] use generic slack url (#3939) 2025-06-08 10:53:27 +02:00
Maycon Santos
b56f61bf1b [misc] fix relay exposed address test (#3931) 2025-06-05 15:44:44 +02:00
Viktor Liu
64f111923e [client] Increase stun status probe timeout (#3930) 2025-06-05 15:22:59 +02:00
Abdul Latif
122a89c02b [misc] remove error causing dnf config-manager add (#3925) 2025-06-05 14:28:19 +02:00
Robert Neumann
c6cceba381 Update getting-started-with-zitadel.sh - fix zitadel user console (#3446) 2025-06-05 14:16:04 +02:00
Ghazy Abdallah
6c0cdb6ed1 [misc] fix: traefik relay accessibility (#3696) 2025-06-05 14:15:01 +02:00
Viktor Liu
84354951d3 [client] Add systemd netbird logs to debug bundle (#3917) 2025-06-05 13:54:15 +02:00
Viktor Liu
55957a1960 [client] Run registerdns before flushing (#3926)
* Run registerdns before flushing

* Disable WINS, dynamic updates and registration
2025-06-05 12:40:23 +02:00
Viktor Liu
df82a45d99 [client] Improve dns match trace log (#3928) 2025-06-05 12:39:58 +02:00
Zoltan Papp
9424b88db2 [client] Add output similar to wg show to the debug package (#3922) 2025-06-05 11:51:39 +02:00
Viktor Liu
609654eee7 [client] Allow userspace local forwarding to internal interfaces if requested (#3884) 2025-06-04 18:12:48 +02:00
Bethuel Mmbaga
b604c66140 [management] Add postgres support for activity event store (#3890) 2025-06-04 17:38:49 +03:00
Viktor Liu
ea4d13e96d [client] Use platform-native routing APIs for freeBSD, macOS and Windows 2025-06-04 16:28:58 +02:00
Pedro Maia Costa
87148c503f [management] support account retrieval and creation by private domain (#3825)
* [management] sys initiator save user (#3911)

* [management] activity events with multiple external account users (#3914)
2025-06-04 11:21:31 +01:00
Viktor Liu
0cd36baf67 [client] Allow the netbird service to log to console (#3916) 2025-06-03 13:09:39 +02:00
Viktor Liu
06980e7fa0 [client] Apply routes right away instead of on peer connection (#3907) 2025-06-03 10:53:39 +02:00
Viktor Liu
1ce4ee0cef [client] Add block inbound flag to disallow inbound connections of any kind (#3897) 2025-06-03 10:53:27 +02:00
Viktor Liu
f367925496 [client] Log duplicate client ui pid (#3915) 2025-06-03 10:52:10 +02:00
hakansa
616b19c064 [client] Add "Deselect All" Menu Item to Exit Node Menu (#3877)
* [client] Enhance exit node menu functionality with deselect all option

* Hide exit nodes before removal in recreateExitNodeMenu

* recreateExitNodeMenu adding mutex locks

* Refetch exit nodes after deselecting all in exit node menu
2025-06-03 09:49:13 +02:00
Zoltan Papp
af27aaf9af [client] Refactor peer state change subscription mechanism (#3910)
* Refactor peer state change subscription mechanism

Because the code generated new channel for every single event, was easy to miss notification.
Use single channel.

* Fix lint

* Avoid potential deadlock

* Fix test

* Add context

* Fix test
2025-06-03 09:20:33 +02:00
Maycon Santos
35287f8241 [misc] Fail linter workflows on codespell failures (#3913)
* Fail linter workflows on codespell failures

* testing workflow

* remove test
2025-06-03 00:37:51 +02:00
Pedro Maia Costa
07b220d91b [management] REST client impersonation (#3879) 2025-06-02 22:11:28 +02:00
Viktor Liu
41cd4952f1 [client] Apply return traffic rules only if firewall is stateless (#3895) 2025-06-02 12:11:54 +02:00
Zoltan Papp
f16f0c7831 [client] Fix HA router switch (#3889)
* Fix HA router switch.

- Simplify the notification filter logic.
Always send notification if a state has been changed

- Remove IP changes check because we never modify

* Notify only the proper listeners

* Fix test

* Fix TestGetPeerStateChangeNotifierLogic test

* Before lazy connection, when the peer disconnected, the status switched to disconnected.
After implementing lazy connection, the peer state is connecting, so we did not decrease the reference counters on the routes.

* When switch to idle notify the route mgr
2025-06-01 16:08:27 +02:00
Viktor Liu
273160c682 [client] Use punycode domains internally consequently (#3867) 2025-05-24 18:25:15 +02:00
bcmmbaga
1d6c360aec fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-23 13:07:26 +03:00
bcmmbaga
f04e7c3f06 Merge branch 'main' into add-ns-punnycode-support
# Conflicts:
#	management/server/nameserver.go
#	management/server/nameserver_test.go
2025-05-23 13:00:19 +03:00
bcmmbaga
3d89cd43c2 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:30 +03:00
bcmmbaga
0eeda712d0 add support for punycode domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:12 +03:00
bcmmbaga
3e3268db5f Remove support for wildcard ns match domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 19:01:53 +03:00
bcmmbaga
31f0879e71 remove the leading dot and root dot support ns regex
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 18:51:05 +03:00
bcmmbaga
f25b5bb987 Enhance match domain validation logic and add test cases
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 16:35:45 +03:00
151 changed files with 4489 additions and 2702 deletions

View File

@@ -21,7 +21,6 @@ jobs:
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1
golangci: golangci:
strategy: strategy:
fail-fast: false fail-fast: false

View File

@@ -172,11 +172,11 @@ jobs:
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values # check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$' grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445" grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"' grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true' grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0 grep LoginFlag management.json | grep 0

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g"> <a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -29,7 +29,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a> Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
<br/> <br/>
</strong> </strong>

View File

@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip] return a.ipAnonymizer[ip]
} }
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs // isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 { if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -39,7 +39,6 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
enableLazyConnectionFlag = "enable-lazy-connection" enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle" uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url" uploadBundleURL = "upload-bundle-url"
@@ -78,7 +77,6 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool
debugUploadBundle bool debugUploadBundle bool
debugUploadBundleURL string debugUploadBundleURL string
lazyConnEnabled bool lazyConnEnabled bool

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"runtime"
"sync" "sync"
"github.com/kardianos/service" "github.com/kardianos/service"
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
} }
func newSVCConfig() *service.Config { func newSVCConfig() *service.Config {
return &service.Config{ config := &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.", Description: "Netbird mesh network client",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string),
} }
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
} }
func newSVC(prg *program, conf *service.Config) (service.Service, error) { func newSVC(prg *program, conf *service.Config) (service.Service, error) {

View File

@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
} }
if logFile != "console" { if logFile != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
} }

View File

@@ -6,6 +6,8 @@ const (
disableServerRoutesFlag = "disable-server-routes" disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns" disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
) )
var ( var (
@@ -13,6 +15,8 @@ var (
disableServerRoutes bool disableServerRoutes bool
disableDNS bool disableDNS bool
disableFirewall bool disableFirewall bool
blockLANAccess bool
blockInbound bool
) )
func init() { func init() {
@@ -28,4 +32,11 @@ func init() {
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false, upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.") "Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
} }

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: ` Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(3),
RunE: tracePacket, RunE: tracePacket,

View File

@@ -55,12 +55,11 @@ func init() {
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
) )
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil, upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+ `Sets DNS labels`+
@@ -119,83 +118,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
ic := internal.ConfigInput{ ic, err := setupConfig(customDNSAddressConverted, cmd)
ManagementURL: managementURL, if err != nil {
AdminURL: adminURL, return fmt.Errorf("setup config: %v", err)
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
} }
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
@@ -203,7 +128,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(*ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
} }
@@ -262,9 +187,141 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
if err != nil { if err != nil {
return err return fmt.Errorf("get setup key: %v", err)
} }
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
return &ic, nil
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
@@ -301,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return err return nil, err
} }
loginRequest.InterfaceName = &interfaceName loginRequest.InterfaceName = &interfaceName
} }
@@ -336,49 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.BlockLanAccess = &blockLANAccess loginRequest.BlockLanAccess = &blockLANAccess
} }
if cmd.Flag(blockInboundFlag).Changed {
loginRequest.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled loginRequest.LazyConnectionEnabled = &lazyConnEnabled
} }
return &loginRequest, nil
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
} }
func validateNATExternalIPs(list []string) error { func validateNATExternalIPs(list []string) error {

View File

@@ -147,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -198,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil, nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
"all", firewall.ProtocolALL,
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
@@ -219,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -2,7 +2,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -248,10 +248,6 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -278,10 +274,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if pair.Masquerade { if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)

View File

@@ -116,6 +116,8 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering( AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,

View File

@@ -170,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -324,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -3,7 +3,6 @@ package nftables
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1").Unmap()
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: add.AsSlice(), Data: ip.AsSlice(),
}, },
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
}) })
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(

View File

@@ -573,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -1006,10 +1002,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }

View File

@@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip net.IP ip tcpip.Address
netstack bool netstack bool
} }
@@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("failed to create NIC: %v", err)
} }
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: ones, PrefixLen: iface.Address().Network.Bits(),
}, },
} }
@@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: iface.Address().IP, ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) { if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1) return net.IPv4(127, 0, 0, 1)
} }
return addr.AsSlice() return addr.AsSlice()
@@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
} }
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil { if errInToOut != nil {
if !isClosedError(errInToOut) { if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut) f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
} }
} }
if errOutToIn != nil { if errOutToIn != nil {
if !isClosedError(errOutToIn) { if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn) f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
} }
} }

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait() wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) { if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr) f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
} }
if inboundErr != nil && !isClosedError(inboundErr) { if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr) f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
} }
var rxPackets, txPackets uint64 var rxPackets, txPackets uint64

View File

@@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
} }
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if ipv4 := ip.To4(); ipv4 != nil { if !ip.Is4() {
high := uint16(ipv4[0]) return
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) }
ipv4 := ip.AsSlice()
if bitmap[high] == nil { high := uint16(ipv4[0])
bitmap[high] = &ipv4LowBitmap{} low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
}
index := low / 32 if bitmap[high] == nil {
bit := low % 32 bitmap[high] = &ipv4LowBitmap{}
bitmap[high].bitmap[index] |= 1 << bit }
ipStr := ipv4.String() index := low / 32
if _, exists := ipv4Set[ipStr]; !exists { bit := low % 32
ipv4Set[ipStr] = struct{}{} bitmap[high].bitmap[index] |= 1 << bit
*ipv4Addresses = append(*ipv4Addresses, ipStr)
} if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
} }
} }
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
} }
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil { addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{}) ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []string var ipv4Addresses []netip.Addr
// 127.0.0.0/8 // 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{} newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match - addresses 32 apart", name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.33"), testIP: netip.MustParseAddr("192.168.1.33"),
expected: false, expected: false,
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: netip.MustParseAddr("fe80::1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,

View File

@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -39,8 +39,12 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack // EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible // Default off as it might be security risk because sockets listening on localhost only will become accessible.
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
@@ -71,7 +75,6 @@ type Manager struct {
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -148,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
} }
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
} }
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
@@ -269,7 +277,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported(): case !m.netstack && m.nativeFirewall != nil:
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
@@ -326,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
@@ -606,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true return true
} }
if m.stateful { // for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, size)
}
return false return false
} }
@@ -777,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true return true
} }
// if running in netstack mode we need to pass this to the forwarder // If requested we pass local traffic to internal interfaces to the forwarder.
if m.netstack && m.localForwarding { // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
return m.handleNetstackLocalTraffic(packetData) if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
return m.handleForwardedLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
@@ -789,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return false return false
} }
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool { func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load() fwd := m.forwarder.Load()
if fwd == nil { if fwd == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
@@ -1088,11 +1099,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true return true
} }
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not

View File

@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup // Apply scenario-specific setup
sc.setupFunc(manager) sc.setupFunc(manager)
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table // Pre-populate connection table
srcIPs := generateRandomIPs(count) srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count)
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0] srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "post_handshake", state: "post_handshake",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err) require.NoError(b, err)
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept) dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -19,12 +19,8 @@ import (
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
localIP := net.ParseIP("100.10.0.100") localIP := netip.MustParseAddr("100.10.0.100")
wgNet := &net.IPNet{ wgNet := netip.MustParsePrefix("100.10.0.0/16")
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
@@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
localIP, wgNet, err := net.ParseCIDR(network) wgNet := netip.MustParsePrefix(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: localIP, IP: wgNet.Addr(),
Network: wgNet, Network: wgNet,
} }
}, },
@@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
@@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
@@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{

View File

@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) { if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }

View File

@@ -12,6 +12,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var zeroKey wgtypes.Key
type KernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
@@ -201,6 +203,47 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() { func (c *KernelConfigurer) Close() {
} }
func (c *KernelConfigurer) FullStats() (*Stats, error) {
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) { func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats) stats := make(map[string]WGStats)
wg, err := wgctrl.New() wg, err := wgctrl.New()

View File

@@ -19,10 +19,17 @@ import (
) )
const ( const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec" ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec" ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes" ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes" ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
) )
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -186,6 +193,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error
@@ -365,3 +381,136 @@ func getFwmark() int {
} }
return 0 return 0
} }
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -0,0 +1,24 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
@@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
} }
} }
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList())
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {

View File

@@ -1,7 +1,6 @@
package device package device
import ( import (
"net"
"net/netip" "net/netip"
"sync" "sync"
@@ -24,9 +23,6 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
} }
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets

View File

@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1) dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@@ -17,4 +17,5 @@ type WGConfigurer interface {
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
} }

View File

@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
ip := address.IP.String() ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -8,10 +8,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address

View File

@@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil
@@ -217,6 +216,10 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats() return w.configurer.GetStats()
} }
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime) timeout := time.NewTimer(maxWaitTime)

View File

@@ -2,7 +2,11 @@
package iface package iface
import "fmt" import (
"fmt"
"github.com/netbirdio/netbird/management/domain"
)
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
@@ -21,6 +25,6 @@ func (w *WGIface) Create() error {
} }
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on non mobile") return fmt.Errorf("this function has not implemented on non mobile")
} }

View File

@@ -2,11 +2,13 @@ package iface
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/management/domain"
) )
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -7,6 +7,8 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/management/domain"
) )
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
@@ -36,6 +38,6 @@ func (w *WGIface) Create() error {
} }
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on this platform") return fmt.Errorf("this function has not implemented on this platform")
} }

View File

@@ -5,7 +5,6 @@
package mocks package mocks
import ( import (
net "net"
"net/netip" "net/netip"
reflect "reflect" reflect "reflect"
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
} }
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,8 +1,6 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"strconv" "strconv"
@@ -15,8 +13,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address net.IP address netip.Addr
dnsAddress net.IP dnsAddress netip.Addr
mtu int mtu int
listenAddress string listenAddress string
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
} }
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{addr.Unmap()}, []netip.Addr{t.address},
[]netip.Addr{dnsAddr.Unmap()}, []netip.Addr{t.dnsAddress},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -2,28 +2,27 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net" "net/netip"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP net.IP IP netip.Addr
Network *net.IPNet Network netip.Prefix
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) prefix, err := netip.ParsePrefix(address)
if err != nil { if err != nil {
return Address{}, err return Address{}, err
} }
return Address{ return Address{
IP: ip, IP: prefix.Addr().Unmap(),
Network: network, Network: prefix.Masked(),
}, nil }, nil
} }
func (addr Address) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -58,6 +58,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
@@ -69,14 +74,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total) time.Since(start), total)
}() }()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@@ -285,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete. if d.firewall.IsStateful() {
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")

View File

@@ -1,13 +1,14 @@
package acl package acl
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) assert.Equal(t, 0, len(acl.peerRulesPairs))
return } else {
assert.Equal(t, 2, len(acl.peerRulesPairs))
} }
}) })
@@ -94,12 +91,13 @@ func TestDefaultManager(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules expectedRules := 2
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied") expectedRules = 1 // only the inbound rule
return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
previousCount++ previousCount++
} }
} }
if previousCount != 1 {
t.Errorf("old rule was not removed") expectedPreviousCount := 0
if !fw.IsStateful() {
expectedPreviousCount = 1
} }
assert.Equal(t, expectedPreviousCount, previousCount)
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 { acl.ApplyFiltering(networkMap, false)
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) assert.Equal(t, 0, len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) expectedRules := 1
return if fw.IsStateful() {
expectedRules = 1 // only inbound allow-all rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
})
}
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
// In stateless mode, we should have both inbound and outbound rules
assert.False(t, fw.IsStateful())
assert.Equal(t, 2, len(acl.peerRulesPairs))
}) })
} }
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) rules, _ := manager.squashAcceptRules(networkMap)
if len(rules) != 2 { assert.Equal(t, 2, len(rules))
t.Errorf("rules should contain 2, got: %v", rules)
return
}
r := rules[0] r := rules[0]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1] r = rules[1]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
} }
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
@@ -291,9 +326,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
} }
manager := &DefaultManager{} manager := &DefaultManager{}
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) { rules, _ := manager.squashAcceptRules(networkMap)
t.Errorf("we should get the same amount of rules as output, got %v", len(rules)) assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
} }
func TestDefaultManagerEnableSSHRules(t *testing.T) { func TestDefaultManagerEnableSSHRules(t *testing.T) {
@@ -336,33 +370,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 { expectedRules := 3
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) if fw.IsStateful() {
return expectedRules = 3 // 2 inbound rules + SSH rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
} }

View File

@@ -68,8 +68,8 @@ type ConfigInput struct {
DisableServerRoutes *bool DisableServerRoutes *bool
DisableDNS *bool DisableDNS *bool
DisableFirewall *bool DisableFirewall *bool
BlockLANAccess *bool
BlockLANAccess *bool BlockInbound *bool
DisableNotifications *bool DisableNotifications *bool
@@ -98,8 +98,8 @@ type Config struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool
BlockLANAccess bool BlockInbound bool
DisableNotifications *bool DisableNotifications *bool
@@ -483,6 +483,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
if *input.BlockInbound {
log.Infof("blocking inbound connections")
} else {
log.Infof("allowing inbound connections")
}
config.BlockInbound = *input.BlockInbound
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications { if *input.DisableNotifications {
log.Infof("disabling notifications") log.Infof("disabling notifications")

View File

@@ -436,11 +436,12 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DNSRouteInterval: config.DNSRouteInterval, DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes, DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes, DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
DisableDNS: config.DisableDNS, DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall, DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
BlockInbound: config.BlockInbound,
BlockLANAccess: config.BlockLANAccess,
LazyConnectionEnabled: config.LazyConnectionEnabled, LazyConnectionEnabled: config.LazyConnectionEnabled,
} }

View File

@@ -270,11 +270,21 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err) log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
} }
if g.logFile != "console" { if err := g.addWgShow(); err != nil {
if err := g.addLogfile(); err != nil { log.Errorf("Failed to add wg show output: %v", err)
return fmt.Errorf("add log file: %w", err)
}
} }
if g.logFile != "console" && g.logFile != "" {
if err := g.addLogfile(); err != nil {
log.Errorf("Failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs: %v", err)
}
return nil return nil
} }
@@ -366,17 +376,33 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled)) configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive)) configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
if g.internalConfig.ServerSSHAllowed != nil { if g.internalConfig.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed)) configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
} }
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS)) configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall)) configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
if g.internalConfig.DisableNotifications != nil {
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
}
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
if g.internalConfig.ClientCertPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath))
}
if g.internalConfig.ClientCertKeyPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
} }

View File

@@ -4,17 +4,104 @@ package debug
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sort" "sort"
"strings" "strings"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
maxLogEntries = 100000
maxLogAge = 7 * 24 * time.Hour // Last 7 days
)
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
func (g *BundleGenerator) trySystemdLogFallback() error {
log.Debug("Attempting to collect systemd journal logs")
serviceName := getServiceName()
journalLogs, err := getSystemdLogs(serviceName)
if err != nil {
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
}
if strings.Contains(journalLogs, "No recent log entries found") {
log.Debug("No recent log entries found in systemd journal")
return nil
}
if g.anonymize {
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
}
logReader := strings.NewReader(journalLogs)
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
if err := g.addFileToZip(logReader, fileName); err != nil {
return fmt.Errorf("add systemd logs to bundle: %w", err)
}
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
return nil
}
// getServiceName gets the service name from environment or defaults to netbird
func getServiceName() string {
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
return unitName
}
return "netbird"
}
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
func getSystemdLogs(serviceName string) (string, error) {
args := []string{
"-u", fmt.Sprintf("%s.service", serviceName),
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
"--lines", fmt.Sprintf("%d", maxLogEntries),
"--no-pager",
"--output", "short-iso",
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "journalctl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
}
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("journalctl command not found: %w", err)
}
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
}
logs := stdout.String()
if strings.TrimSpace(logs) == "" {
return "No recent log entries found in systemd journal", nil
}
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
serviceName, maxLogEntries, maxLogAge.String())
return header + logs, nil
}
// addFirewallRules collects and adds firewall rules to the archive // addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules") log.Info("Collecting firewall rules")
@@ -481,7 +568,7 @@ func formatExpr(exp expr.Any) string {
case *expr.Fib: case *expr.Fib:
return formatFib(e) return formatFib(e)
case *expr.Target: case *expr.Target:
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets return fmt.Sprintf("jump %s", e.Name)
case *expr.Immediate: case *expr.Immediate:
if e.Register == 1 { if e.Register == 1 {
return formatImmediateData(e.Data) return formatImmediateData(e.Data)

View File

@@ -6,3 +6,9 @@ package debug
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
return nil return nil
} }
func (g *BundleGenerator) trySystemdLogFallback() error {
// Systemd is only available on Linux
// TODO: Add BSD support
return nil
}

View File

@@ -0,0 +1,66 @@
package debug
import (
"bytes"
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGIface interface {
FullStats() (*configurer.Stats, error)
}
func (g *BundleGenerator) addWgShow() error {
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err
}
output := g.toWGShowFormat(result)
reader := bytes.NewReader([]byte(output))
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
return fmt.Errorf("add wg show to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
if s.FWMark != 0 {
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
}
for _, peer := range s.Peers {
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
if peer.Endpoint.IP != nil {
if g.anonymize {
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
} else {
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
}
}
if len(peer.AllowedIPs) > 0 {
var ipStrings []string
for _, ipnet := range peer.AllowedIPs {
ipStrings = append(ipStrings, ipnet.String())
}
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
sb.WriteString(" preshared key: (hidden)\n")
}
}
return sb.String()
}

View File

@@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"net" "net/netip"
"slices" "slices"
"strings" "strings"
@@ -12,13 +12,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData) ip, err := netip.ParseAddr(aRecord.RData)
if ip == nil || ip.To4() == nil { if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
if !ipNet.Contains(ip) { if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
} }
// generateReverseZoneName creates the reverse DNS zone name for a given network // generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) { func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask) networkIP := network.Masked().Addr()
maskOnes, _ := ipNet.Mask.Size()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// round up to nearest byte // round up to nearest byte
octetsToUse := (maskOnes + 7) / 8 octetsToUse := (network.Bits() + 7) / 8
octets := strings.Split(networkIP.String(), ".") octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) { if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
} }
reverseOctets := make([]string, octetsToUse) reverseOctets := make([]string, octetsToUse)
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
} }
// collectPTRRecords gathers all PTR records for the given network from A records // collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones { for _, zone := range config.CustomZones {
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
continue continue
} }
if ptrRecord, ok := createPTRRecord(record, ipNet); ok { if ptrRecord, ok := createPTRRecord(record, prefix); ok {
records = append(records, ptrRecord) records = append(records, ptrRecord)
} }
} }
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
} }
// addReverseZone adds a reverse DNS zone to the configuration for the given network // addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { func addReverseZone(config *nbdns.Config, network netip.Prefix) {
zoneName, err := generateReverseZoneName(ipNet) zoneName, err := generateReverseZoneName(network)
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
return return
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
return return
} }
records := collectPTRRecords(config, ipNet) records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{ reverseZone := nbdns.CustomZone{
Domain: zoneName, Domain: zoneName,

View File

@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
continue continue
} }
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, ".")) listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
} }
return listOfDomains return listOfDomains
} }

View File

@@ -1,12 +1,15 @@
package dns package dns
import ( import (
"fmt"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
) )
const ( const (
@@ -23,8 +26,8 @@ type SubdomainMatcher interface {
type HandlerEntry struct { type HandlerEntry struct {
Handler dns.Handler Handler dns.Handler
Priority int Priority int
Pattern string Pattern domain.Domain
OrigPattern string OrigPattern domain.Domain
IsWildcard bool IsWildcard bool
MatchSubdomains bool MatchSubdomains bool
} }
@@ -38,7 +41,7 @@ type HandlerChain struct {
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain // ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct { type ResponseWriterChain struct {
dns.ResponseWriter dns.ResponseWriter
origPattern string origPattern domain.Domain
shouldContinue bool shouldContinue bool
} }
@@ -58,18 +61,18 @@ func NewHandlerChain() *HandlerChain {
} }
// GetOrigPattern returns the original pattern of the handler that wrote the response // GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string { func (w *ResponseWriterChain) GetOrigPattern() domain.Domain {
return w.origPattern return w.origPattern
} }
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) { func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern)) pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString())))
origPattern := pattern origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.") isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.")
if isWildcard { if isWildcard {
pattern = pattern[2:] pattern = pattern[2:]
} }
@@ -109,8 +112,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
// domain specificity next // domain specificity next
if h.Priority == newEntry.Priority { if h.Priority == newEntry.Priority {
newDots := strings.Count(newEntry.Pattern, ".") newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".")
existingDots := strings.Count(h.Pattern, ".") existingDots := strings.Count(h.Pattern.PunycodeString(), ".")
if newDots > existingDots { if newDots > existingDots {
return i return i
} }
@@ -122,20 +125,20 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
} }
// RemoveHandler removes a handler for the given pattern and priority // RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) { func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = dns.Fqdn(pattern) pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString()))
c.removeEntry(pattern, priority) c.removeEntry(pattern, priority)
} }
func (c *HandlerChain) removeEntry(pattern string, priority int) { func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) {
// Find and remove handlers matching both original pattern (case-insensitive) and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break break
} }
@@ -148,61 +151,42 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
qname := strings.ToLower(r.Question[0].Name) qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
handlers := slices.Clone(c.handlers) handlers := slices.Clone(c.handlers)
c.mu.RUnlock() c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers)) var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers { for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
} }
log.Trace(strings.TrimSuffix(b.String(), "\n"))
} }
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
var matched bool matched := c.isHandlerMatch(qname, entry)
switch {
case entry.Pattern == ".": if matched {
matched = true log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
case entry.IsWildcard: qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) chainWriter := &ResponseWriterChain{
default: ResponseWriter: w,
// For non-wildcard patterns: origPattern: entry.OrigPattern,
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = strings.EqualFold(qname, entry.Pattern)
} }
} entry.Handler.ServeDNS(chainWriter, r)
if !matched { // If handler wants to continue, try next handler
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false", if chainWriter.shouldContinue {
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority) log.Tracef("handler requested continue to next handler for domain=%s", qname)
continue continue
}
return
} }
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler")
continue
}
return
} }
// No handler matched or all handlers passed // No handler matched or all handlers passed
@@ -213,3 +197,22 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
} }
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString())
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString())
} else {
return strings.EqualFold(qname, entry.Pattern.PunycodeString())
}
}
}

View File

@@ -9,6 +9,7 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/management/domain"
) )
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
handlerDomain string handlerDomain domain.Domain
queryDomain string queryDomain domain.Domain
isWildcard bool isWildcard bool
matchSubdomains bool matchSubdomains bool
shouldMatch bool shouldMatch bool
@@ -141,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
chain.AddHandler(pattern, handler, nbdns.PriorityDefault) chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
@@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
handlers []struct { handlers []struct {
pattern string pattern domain.Domain
priority int priority int
} }
queryDomain string queryDomain domain.Domain
expectedCalls int expectedCalls int
expectedHandler int // index of the handler that should be called expectedHandler int // index of the handler that should be called
}{ }{
{ {
name: "wildcard and exact same priority - exact should win", name: "wildcard and exact same priority - exact should win",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "higher priority wildcard over lower priority exact", name: "higher priority wildcard over lower priority exact",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "example.com.", priority: nbdns.PriorityDefault}, {pattern: "example.com.", priority: nbdns.PriorityDefault},
@@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "multiple wildcards different priorities", name: "multiple wildcards different priorities",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "subdomain with mix of patterns", name: "subdomain with mix of patterns",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "root zone with specific domain", name: "root zone with specific domain",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: ".", priority: nbdns.PriorityDefault}, {pattern: ".", priority: nbdns.PriorityDefault},
@@ -258,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request // Create and execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
@@ -330,7 +331,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name string name string
ops []struct { ops []struct {
action string // "add" or "remove" action string // "add" or "remove"
pattern string pattern domain.Domain
priority int priority int
} }
query string query string
@@ -340,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove high priority keeps lower priority handler", name: "remove high priority keeps lower priority handler",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -357,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove lower priority keeps high priority handler", name: "remove lower priority keeps high priority handler",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -374,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove all handlers in order", name: "remove all handlers in order",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -436,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
func TestHandlerChain_MultiPriorityHandling(t *testing.T) { func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
testDomain := "example.com." testDomain := domain.Domain("example.com.")
testQuery := "test.example.com." testQuery := "test.example.com."
// Create handlers with MatchSubdomains enabled // Create handlers with MatchSubdomains enabled
@@ -518,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name string name string
scenario string scenario string
addHandlers []struct { addHandlers []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -530,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive exact match", name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase", scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -544,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive wildcard match", name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case", scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -558,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different case same domain", name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference", scenario: "second handler should replace first despite case difference",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -573,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "subdomain matching case insensitive", name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case", scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -587,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "root zone case insensitive", name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case", scenario: "root zone handler should match regardless of case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -601,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different priority", name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences", scenario: "should call higher priority handler despite case differences",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -618,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called handlerCalls := make(map[domain.Domain]bool) // track which patterns were called
// Add handlers according to test case // Add handlers according to test case
for _, h := range tt.addHandlers { for _, h := range tt.addHandlers {
@@ -686,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario string scenario string
ops []struct { ops []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
} }
query string query domain.Domain
expectedMatch string expectedMatch domain.Domain
}{ }{
{ {
name: "more specific domain matches first", name: "more specific domain matches first",
scenario: "sub.example.com should match before example.com", scenario: "sub.example.com should match before example.com",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -713,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "sub.example.com should match before example.com", scenario: "sub.example.com should match before example.com",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -728,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "after removing most specific, should fall back to less specific", scenario: "after removing most specific, should fall back to less specific",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -745,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "less specific domain with higher priority should match first", scenario: "less specific domain with higher priority should match first",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -760,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "with equal priority, more specific domain should match", scenario: "with equal priority, more specific domain should match",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -776,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "specific domain should match before wildcard at same priority", scenario: "specific domain should match before wildcard at same priority",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -791,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
handlers := make(map[string]*nbdns.MockSubdomainHandler) handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler)
for _, op := range tt.ops { for _, op := range tt.ops {
if op.action == "add" { if op.action == "add" {
@@ -804,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
} }
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations // Setup handler expectations
@@ -836,9 +837,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
addPattern string addPattern domain.Domain
removePattern string removePattern domain.Domain
queryPattern string queryPattern domain.Domain
shouldBeRemoved bool shouldBeRemoved bool
description string description string
}{ }{
@@ -954,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
handler := &nbdns.MockHandler{} handler := &nbdns.MockHandler{}
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA) r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any // First verify no handler is called before adding any

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
@@ -39,9 +40,9 @@ type HostDNSConfig struct {
} }
type DomainConfig struct { type DomainConfig struct {
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
Domain string `json:"domain"` Domain domain.Domain `json:"domain"`
MatchOnly bool `json:"matchOnly"` MatchOnly bool `json:"matchOnly"`
} }
type mockHostConfigurator struct { type mockHostConfigurator struct {
@@ -103,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
config.RouteAll = true config.RouteAll = true
} }
for _, domain := range nsConfig.Domains { for _, d := range nsConfig.Domains {
d := strings.ToLower(dns.Fqdn(d.PunycodeString()))
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(domain)), Domain: domain.Domain(d),
MatchOnly: !nsConfig.SearchDomainsEnabled, MatchOnly: !nsConfig.SearchDomainsEnabled,
}) })
} }
} }
for _, customZone := range dnsConfig.CustomZones { for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) d := strings.ToLower(dns.Fqdn(customZone.Domain))
matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)), Domain: domain.Domain(d),
MatchOnly: matchOnly, MatchOnly: matchOnly,
}) })
} }

View File

@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, ".")) matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
continue continue
} }
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, ".")) searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), "."))
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)

View File

@@ -1,11 +1,14 @@
package dns package dns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os/exec"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -41,6 +44,20 @@ const (
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
registrationEnabledKey = "RegistrationEnabled"
maxNumberOfAddressesToRegisterKey = "MaxNumberOfAddressesToRegister"
// NetBIOS/WINS settings
netbtInterfacePath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces`
netbiosOptionsKey = "NetbiosOptions"
// NetBIOS option values: 0 = from DHCP, 1 = enabled, 2 = disabled
netbiosFromDHCP = 0
netbiosEnabled = 1
netbiosDisabled = 2
// RP_FORCE: Reapply all policies even if no policy change was detected // RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1 rpForce = 0x1
) )
@@ -67,16 +84,85 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
log.Infof("detected GPO DNS policy configuration, using policy store") log.Infof("detected GPO DNS policy configuration, using policy store")
} }
return &registryConfigurator{ configurator := &registryConfigurator{
guid: guid, guid: guid,
gpo: useGPO, gpo: useGPO,
}, nil }
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
return configurator, nil
} }
func (r *registryConfigurator) supportCustomPort() bool { func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) configureInterface() error {
var merr *multierror.Error
if err := r.disableDNSRegistrationForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable DNS registration: %w", err))
}
if err := r.disableWINSForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable WINS: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableDNSRegistrationForInterface() error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closer(regKey)
var merr *multierror.Error
if err := regKey.SetDWordValue(disableDynamicUpdateKey, 1); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", disableDynamicUpdateKey, err))
}
if err := regKey.SetDWordValue(registrationEnabledKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", registrationEnabledKey, err))
}
if err := regKey.SetDWordValue(maxNumberOfAddressesToRegisterKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", maxNumberOfAddressesToRegisterKey, err))
}
if merr == nil || len(merr.Errors) == 0 {
log.Infof("disabled DNS registration for interface %s", r.guid)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableWINSForInterface() error {
netbtKeyPath := fmt.Sprintf(`%s\Tcpip_%s`, netbtInterfacePath, r.guid)
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
regKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create NetBT interface key %s: %w", netbtKeyPath, err)
}
}
defer closer(regKey)
// NetbiosOptions: 2 = disabled
if err := regKey.SetDWordValue(netbiosOptionsKey, netbiosDisabled); err != nil {
return fmt.Errorf("set %s: %w", netbiosOptionsKey, err)
}
log.Infof("disabled WINS/NetBIOS for interface %s", r.guid)
return nil
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if config.RouteAll { if config.RouteAll {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil { if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
@@ -100,9 +186,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
continue continue
} }
if !dConf.MatchOnly { if !dConf.MatchOnly {
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, ".")) searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
} }
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
@@ -119,9 +205,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
} }
if err := r.flushDNSCache(); err != nil { go r.flushDNSCache()
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }
@@ -191,7 +275,25 @@ func (r *registryConfigurator) string() string {
return "registry" return "registry"
} }
func (r *registryConfigurator) flushDNSCache() error { func (r *registryConfigurator) registerDNS() {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
// nolint:misspell
cmd := exec.CommandContext(ctx, "ipconfig", "/registerdns")
out, err := cmd.CombinedOutput()
if err != nil {
log.Errorf("failed to register DNS: %v, output: %s", err, out)
return
}
log.Info("registered DNS names")
}
func (r *registryConfigurator) flushDNSCache() {
r.registerDNS()
// dnsFlushResolverCacheFn.Call() may panic if the func is not found // dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
@@ -202,13 +304,14 @@ func (r *registryConfigurator) flushDNSCache() error {
ret, _, err := dnsFlushResolverCacheFn.Call() ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 { if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) { if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err) log.Errorf("DnsFlushResolverCache failed: %v", err)
return
} }
return fmt.Errorf("DnsFlushResolverCache failed") log.Errorf("DnsFlushResolverCache failed")
return
} }
log.Info("flushed DNS cache") log.Info("flushed DNS cache")
return nil
} }
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
@@ -263,9 +366,7 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := r.flushDNSCache(); err != nil { go r.flushDNSCache()
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }

View File

@@ -62,8 +62,8 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
return fmt.Errorf("method UpdateDNSServer is not implemented") return fmt.Errorf("method UpdateDNSServer is not implemented")
} }
func (m *MockServer) SearchDomains() []string { func (m *MockServer) SearchDomains() domain.List {
return make([]string, 0) return make(domain.List, 0)
} }
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface

View File

@@ -125,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dConf.Domain) matchDomains = append(matchDomains, "~."+dConf.Domain.PunycodeString())
continue continue
} }
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
} }
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic

View File

@@ -1,21 +1,19 @@
package dns package dns
import ( import (
"reflect"
"sort"
"sync" "sync"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/management/domain"
) )
type notifier struct { type notifier struct {
listener listener.NetworkChangeListener listener listener.NetworkChangeListener
listenerMux sync.Mutex listenerMux sync.Mutex
searchDomains []string searchDomains domain.List
} }
func newNotifier(initialSearchDomains []string) *notifier { func newNotifier(initialSearchDomains domain.List) *notifier {
sort.Strings(initialSearchDomains)
return &notifier{ return &notifier{
searchDomains: initialSearchDomains, searchDomains: initialSearchDomains,
} }
@@ -27,16 +25,8 @@ func (n *notifier) setListener(listener listener.NetworkChangeListener) {
n.listener = listener n.listener = listener
} }
func (n *notifier) onNewSearchDomains(searchDomains []string) { func (n *notifier) onNewSearchDomains(searchDomains domain.List) {
sort.Strings(searchDomains) if searchDomains.Equal(n.searchDomains) {
if len(n.searchDomains) != len(searchDomains) {
n.searchDomains = searchDomains
n.notify()
return
}
if reflect.DeepEqual(n.searchDomains, searchDomains) {
return return
} }

View File

@@ -44,12 +44,12 @@ type Server interface {
DnsIP() string DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string SearchDomains() domain.List
ProbeAvailability() ProbeAvailability()
} }
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
domain string domain domain.Domain
groups []*nbdns.NameServerGroup groups []*nbdns.NameServerGroup
} }
@@ -90,7 +90,7 @@ type handlerWithStop interface {
} }
type handlerWrapper struct { type handlerWrapper struct {
domain string domain domain.Domain
handler handlerWithStop handler handlerWithStop
priority int priority int
} }
@@ -197,7 +197,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.registerHandler(domains.ToPunycodeList(), handler, priority) s.registerHandler(domains, handler, priority)
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
for _, domain := range domains { for _, domain := range domains {
@@ -207,7 +207,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
s.applyHostConfig() s.applyHostConfig()
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains domain.List, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority) log.Debugf("registering handler %s with priority %d", handler, priority)
for _, domain := range domains { for _, domain := range domains {
@@ -224,7 +224,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.deregisterHandler(domains.ToPunycodeList(), priority) s.deregisterHandler(domains, priority)
for _, domain := range domains { for _, domain := range domains {
zone := toZone(domain) zone := toZone(domain)
s.extraDomains[zone]-- s.extraDomains[zone]--
@@ -235,7 +235,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.applyHostConfig() s.applyHostConfig()
} }
func (s *DefaultServer) deregisterHandler(domains []string, priority int) { func (s *DefaultServer) deregisterHandler(domains domain.List, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority) log.Debugf("deregistering handler %v with priority %d", domains, priority)
for _, domain := range domains { for _, domain := range domains {
@@ -378,8 +378,8 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
return nil return nil
} }
func (s *DefaultServer) SearchDomains() []string { func (s *DefaultServer) SearchDomains() domain.List {
var searchDomains []string var searchDomains domain.List
for _, dConf := range s.currentConfig.Domains { for _, dConf := range s.currentConfig.Domains {
if dConf.Disabled { if dConf.Disabled {
@@ -472,24 +472,22 @@ func (s *DefaultServer) applyHostConfig() {
config := s.currentConfig config := s.currentConfig
existingDomains := make(map[string]struct{}) existingDomains := make(map[domain.Domain]struct{})
for _, d := range config.Domains { for _, d := range config.Domains {
existingDomains[d.Domain] = struct{}{} existingDomains[d.Domain] = struct{}{}
} }
// add extra domains only if they're not already in the config // add extra domains only if they're not already in the config
for domain := range s.extraDomains { for d := range s.extraDomains {
domainStr := domain.PunycodeString() if _, exists := existingDomains[d]; !exists {
if _, exists := existingDomains[domainStr]; !exists {
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: domainStr, Domain: d,
MatchOnly: true, MatchOnly: true,
}) })
} }
} }
log.Debugf("extra match domains: %v", s.extraDomains) log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err) log.Errorf("failed to apply DNS host manager update: %v", err)
@@ -525,7 +523,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
} }
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: domain.Domain(customZone.Domain),
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}) })
@@ -647,7 +645,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests // this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap { for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority) s.deregisterHandler(domain.List{existing.domain}, existing.priority)
existing.handler.Stop() existing.handler.Stop()
} }
@@ -658,7 +656,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
if update.domain == nbdns.RootZone { if update.domain == nbdns.RootZone {
containsRootUpdate = true containsRootUpdate = true
} }
s.registerHandler([]string{update.domain}, update.handler, update.priority) s.registerHandler(domain.List{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update muxUpdateMap[update.handler.ID()] = update
} }
@@ -687,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
handler dns.Handler, handler dns.Handler,
priority int, priority int,
) (deactivate func(error), reactivate func()) { ) (deactivate func(error), reactivate func()) {
var removeIndex map[string]int var removeIndex map[domain.Domain]int
deactivate = func(err error) { deactivate = func(err error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -695,20 +693,20 @@ func (s *DefaultServer) upstreamCallbacks(
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
l.Info("Temporarily deactivating nameservers group due to timeout") l.Info("Temporarily deactivating nameservers group due to timeout")
removeIndex = make(map[string]int) removeIndex = make(map[domain.Domain]int)
for _, domain := range nsGroup.Domains { for _, domain := range nsGroup.Domains {
removeIndex[domain] = -1 removeIndex[domain] = -1
} }
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
s.deregisterHandler([]string{nbdns.RootZone}, priority) s.deregisterHandler(domain.List{nbdns.RootZone}, priority)
} }
for i, item := range s.currentConfig.Domains { for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found { if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true s.currentConfig.Domains[i].Disabled = true
s.deregisterHandler([]string{item.Domain}, priority) s.deregisterHandler(domain.List{item.Domain}, priority)
removeIndex[item.Domain] = i removeIndex[item.Domain] = i
} }
} }
@@ -732,12 +730,12 @@ func (s *DefaultServer) upstreamCallbacks(
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
for domain, i := range removeIndex { for d, i := range removeIndex {
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != d{
continue continue
} }
s.currentConfig.Domains[i].Disabled = false s.currentConfig.Domains[i].Disabled = false
s.registerHandler([]string{domain}, handler, priority) s.registerHandler(domain.List{d}, handler, priority)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@@ -745,7 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.registerHandler([]string{nbdns.RootZone}, handler, priority) s.registerHandler(domain.List{nbdns.RootZone}, handler, priority)
} }
s.applyHostConfig() s.applyHostConfig()
@@ -777,7 +775,7 @@ func (s *DefaultServer) addHostRootZone() {
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) s.registerHandler(domain.List{nbdns.RootZone}, handler, PriorityDefault)
} }
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
@@ -792,7 +790,7 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
state := peer.NSGroupState{ state := peer.NSGroupState{
ID: generateGroupKey(group), ID: generateGroupKey(group),
Servers: servers, Servers: servers,
Domains: group.Domains, Domains: group.Domains.ToPunycodeList(),
// The probe will determine the state, default enabled // The probe will determine the state, default enabled
Enabled: true, Enabled: true,
Error: nil, Error: nil,
@@ -825,7 +823,7 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
// groupNSGroupsByDomain groups nameserver groups by their match domains // groupNSGroupsByDomain groups nameserver groups by their match domains
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain { func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
domainMap := make(map[string][]*nbdns.NameServerGroup) domainMap := make(map[domain.Domain][]*nbdns.NameServerGroup)
for _, group := range nsGroups { for _, group := range nsGroups {
if group.Primary { if group.Primary {

View File

@@ -6,7 +6,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
@@ -46,10 +45,9 @@ func (w *mocWGIface) Name() string {
} }
func (w *mocWGIface) Address() wgaddr.Address { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{ return wgaddr.Address{
IP: ip, IP: netip.MustParseAddr("100.66.100.1"),
Network: network, Network: netip.MustParsePrefix("100.66.100.0/24"),
} }
} }
@@ -97,7 +95,7 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger()) formatter.SetTextFormatter(log.StandardLogger())
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { func generateDummyHandler(domain domain.Domain, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string var srvs []string
for _, srv := range servers { for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv)) srvs = append(srvs, getNSHostPort(srv))
@@ -152,7 +150,7 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"netbird.io"}, Domains: domain.List{"netbird.io"},
NameServers: nameServers, NameServers: nameServers,
}, },
{ {
@@ -184,7 +182,7 @@ func TestUpdateDNSServer(t *testing.T) {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
@@ -202,7 +200,7 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"netbird.io"}, Domains: domain.List{"netbird.io"},
NameServers: nameServers, NameServers: nameServers,
}, },
}, },
@@ -303,8 +301,8 @@ func TestUpdateDNSServer(t *testing.T) {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: domain.Domain(zoneRecords[0].Name),
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
@@ -319,8 +317,8 @@ func TestUpdateDNSServer(t *testing.T) {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: domain.Domain(zoneRecords[0].Name),
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
@@ -464,17 +462,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil { if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err) t.Errorf("set packet filter: %v", err)
@@ -501,7 +492,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
dnsServer.dnsMuxMap = registeredHandlerMap{ dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: domain.Domain(zoneRecords[0].Name),
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
@@ -533,7 +524,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}, },
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"netbird.io"}, Domains: domain.List{"netbird.io"},
NameServers: nameServers, NameServers: nameServers,
}, },
{ {
@@ -599,7 +590,7 @@ func TestDNSServerStartStop(t *testing.T) {
t.Error(err) t.Error(err)
} }
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1) dnsServer.registerHandler(domain.List{"netbird.cloud"}, dnsServer.localResolver, 1)
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
@@ -659,48 +650,48 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
var domainsUpdate string var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{} domains := domain.List{}
for _, item := range config.Domains { for _, item := range config.Domains {
if item.Disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.Domain) domains = append(domains, item.Domain)
} }
domainsUpdate = strings.Join(domains, ",") domainsUpdate = domains.PunycodeString()
return nil return nil
} }
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"}, Domains: domain.List{"domain1"},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
}, },
}, nil, 0) }, nil, 0)
deactivate(nil) deactivate(nil)
expected := "domain0,domain2" expected := "domain0, domain2"
domains := []string{} domains := domain.List{}
for _, item := range server.currentConfig.Domains { for _, item := range server.currentConfig.Domains {
if item.Disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.Domain) domains = append(domains, item.Domain)
} }
got := strings.Join(domains, ",") got := domains.PunycodeString()
if expected != got { if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got) t.Errorf("expected domains list: %q, got %q", expected, got)
} }
reactivate() reactivate()
expected = "domain0,domain1,domain2" expected = "domain0, domain1, domain2"
domains = []string{} domains = domain.List{}
for _, item := range server.currentConfig.Domains { for _, item := range server.currentConfig.Domains {
if item.Disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.Domain) domains = append(domains, item.Domain)
} }
got = strings.Join(domains, ",") got = domains.PunycodeString()
if expected != got { if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
} }
@@ -868,7 +859,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53, Port: 53,
}, },
}, },
Domains: []string{"google.com"}, Domains: domain.List{"google.com"},
Primary: false, Primary: false,
}, },
}, },
@@ -1123,7 +1114,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
name string name string
initialHandlers registeredHandlerMap initialHandlers registeredHandlerMap
updates []handlerWrapper updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain expectedHandlers map[string]domain.Domain // map[HandlerID]domain
description string description string
}{ }{
{ {
@@ -1139,7 +1130,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 1, priority: PriorityMatchDomain - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group2": "example.com", "upstream-group2": "example.com",
}, },
description: "When group1 is not included in the update, it should be removed while group2 remains", description: "When group1 is not included in the update, it should be removed while group2 remains",
@@ -1157,7 +1148,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
}, },
description: "When group2 is not included in the update, it should be removed while group1 remains", description: "When group2 is not included in the update, it should be removed while group1 remains",
@@ -1190,7 +1181,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 1, priority: PriorityMatchDomain - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-group2": "example.com", "upstream-group2": "example.com",
"upstream-group3": "example.com", "upstream-group3": "example.com",
@@ -1225,7 +1216,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 2, priority: PriorityMatchDomain - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-group2": "example.com", "upstream-group2": "example.com",
"upstream-group3": "example.com", "upstream-group3": "example.com",
@@ -1245,7 +1236,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 1, priority: PriorityDefault - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root2": ".", "upstream-root2": ".",
}, },
description: "When root1 is not included in the update, it should be removed while root2 remains", description: "When root1 is not included in the update, it should be removed while root2 remains",
@@ -1262,7 +1253,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault, priority: PriorityDefault,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".", "upstream-root1": ".",
}, },
description: "When root2 is not included in the update, it should be removed while root1 remains", description: "When root2 is not included in the update, it should be removed while root1 remains",
@@ -1293,7 +1284,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 1, priority: PriorityDefault - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".", "upstream-root1": ".",
"upstream-root2": ".", "upstream-root2": ".",
"upstream-root3": ".", "upstream-root3": ".",
@@ -1326,7 +1317,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 2, priority: PriorityDefault - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".", "upstream-root1": ".",
"upstream-root2": ".", "upstream-root2": ".",
"upstream-root3": ".", "upstream-root3": ".",
@@ -1353,7 +1344,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-other": "other.com", "upstream-other": "other.com",
}, },
@@ -1392,7 +1383,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-group2": "example.com", "upstream-group2": "example.com",
"upstream-other": "other.com", "upstream-other": "other.com",
@@ -1448,7 +1439,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, muxEntry := range server.dnsMuxMap { for _, muxEntry := range server.dnsMuxMap {
if chainEntry.Handler == muxEntry.handler && if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority && chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) { chainEntry.Pattern.PunycodeString() == dns.Fqdn(muxEntry.domain.PunycodeString()) {
foundInMux = true foundInMux = true
break break
} }
@@ -1467,8 +1458,8 @@ func TestExtraDomains(t *testing.T) {
registerDomains []domain.List registerDomains []domain.List
deregisterDomains []domain.List deregisterDomains []domain.List
finalConfig nbdns.Config finalConfig nbdns.Config
expectedDomains []string expectedDomains domain.List
expectedMatchOnly []string expectedMatchOnly domain.List
applyHostConfigCall int applyHostConfigCall int
}{ }{
{ {
@@ -1482,12 +1473,12 @@ func TestExtraDomains(t *testing.T) {
{Domain: "config.example.com"}, {Domain: "config.example.com"},
}, },
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
@@ -1504,12 +1495,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{ registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"}, {"extra1.example.com", "extra2.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
@@ -1527,12 +1518,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{ registerDomains: []domain.List{
{"extra.example.com", "overlap.example.com"}, {"extra.example.com", "overlap.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"overlap.example.com.", "overlap.example.com.",
"extra.example.com.", "extra.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
}, },
applyHostConfigCall: 2, applyHostConfigCall: 2,
@@ -1552,12 +1543,12 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{ deregisterDomains: []domain.List{
{"extra1.example.com", "extra3.example.com"}, {"extra1.example.com", "extra3.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra2.example.com.", "extra2.example.com.",
"extra4.example.com.", "extra4.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra2.example.com.", "extra2.example.com.",
"extra4.example.com.", "extra4.example.com.",
}, },
@@ -1578,13 +1569,13 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{ deregisterDomains: []domain.List{
{"duplicate.example.com"}, {"duplicate.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra.example.com.", "extra.example.com.",
"other.example.com.", "other.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
"other.example.com.", "other.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
@@ -1609,13 +1600,13 @@ func TestExtraDomains(t *testing.T) {
{Domain: "newconfig.example.com"}, {Domain: "newconfig.example.com"},
}, },
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"newconfig.example.com.", "newconfig.example.com.",
"extra.example.com.", "extra.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
}, },
@@ -1636,12 +1627,12 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{ deregisterDomains: []domain.List{
{"protected.example.com"}, {"protected.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"extra.example.com.", "extra.example.com.",
"config.example.com.", "config.example.com.",
"protected.example.com.", "protected.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
}, },
applyHostConfigCall: 3, applyHostConfigCall: 3,
@@ -1652,7 +1643,7 @@ func TestExtraDomains(t *testing.T) {
ServiceEnable: true, ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"ns.example.com", "overlap.ns.example.com"}, Domains: domain.List{"ns.example.com", "overlap.ns.example.com"},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("8.8.8.8"), IP: netip.MustParseAddr("8.8.8.8"),
@@ -1666,12 +1657,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{ registerDomains: []domain.List{
{"extra.example.com", "overlap.ns.example.com"}, {"extra.example.com", "overlap.ns.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"ns.example.com.", "ns.example.com.",
"overlap.ns.example.com.", "overlap.ns.example.com.",
"extra.example.com.", "extra.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"ns.example.com.", "ns.example.com.",
"overlap.ns.example.com.", "overlap.ns.example.com.",
"extra.example.com.", "extra.example.com.",
@@ -1742,8 +1733,8 @@ func TestExtraDomains(t *testing.T) {
lastConfig := capturedConfigs[len(capturedConfigs)-1] lastConfig := capturedConfigs[len(capturedConfigs)-1]
// Check all expected domains are present // Check all expected domains are present
domainMap := make(map[string]bool) domainMap := make(map[domain.Domain]bool)
matchOnlyMap := make(map[string]bool) matchOnlyMap := make(map[domain.Domain]bool)
for _, d := range lastConfig.Domains { for _, d := range lastConfig.Domains {
domainMap[d.Domain] = true domainMap[d.Domain] = true
@@ -1860,12 +1851,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
err := server.applyConfiguration(initialConfig) err := server.applyConfiguration(initialConfig)
assert.NoError(t, err) assert.NoError(t, err)
var domains []string var domains domain.List
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain) domains = append(domains, d.Domain)
} }
assert.Contains(t, domains, "config.example.com.") assert.Contains(t, domains, domain.Domain("config.example.com."))
assert.Contains(t, domains, "extra.example.com.") assert.Contains(t, domains, domain.Domain("extra.example.com."))
// Now apply a new configuration with overlapping domain // Now apply a new configuration with overlapping domain
updatedConfig := nbdns.Config{ updatedConfig := nbdns.Config{
@@ -1879,7 +1870,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Verify both domains are in config, but no duplicates // Verify both domains are in config, but no duplicates
domains = []string{} domains = domain.List{}
matchOnlyCount := 0 matchOnlyCount := 0
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain) domains = append(domains, d.Domain)
@@ -1888,12 +1879,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
} }
} }
assert.Contains(t, domains, "config.example.com.") assert.Contains(t, domains, domain.Domain("config.example.com."))
assert.Contains(t, domains, "extra.example.com.") assert.Contains(t, domains, domain.Domain("extra.example.com."))
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates") assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
// Extra domain should no longer be marked as match-only when in config // Extra domain should no longer be marked as match-only when in config
matchOnlyDomain := "" var matchOnlyDomain domain.Domain
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
if d.Domain == "extra.example.com." && d.MatchOnly { if d.Domain == "extra.example.com." && d.MatchOnly {
matchOnlyDomain = d.Domain matchOnlyDomain = d.Domain
@@ -1946,10 +1937,10 @@ func TestDomainCaseHandling(t *testing.T) {
err := server.applyConfiguration(config) err := server.applyConfiguration(config)
assert.NoError(t, err) assert.NoError(t, err)
var domains []string var domains domain.List
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain) domains = append(domains, d.Domain)
} }
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, domain.Domain("config.example.com."), "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") assert.Contains(t, domains, domain.Domain("mixed.example.com."), "Mixed case domain should be normalized and present")
} }

View File

@@ -24,11 +24,15 @@ type ServiceViaMemory struct {
} }
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{ s := &ServiceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(), runtimeIP: lastIP.String(),
runtimePort: defaultPort, runtimePort: defaultPort,
} }
return s return s
@@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
} }
firstLayerDecoder := layers.LayerTypeIPv4 firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil { if s.wgInterface.Address().IP.Is6() {
firstLayerDecoder = layers.LayerTypeIPv6 firstLayerDecoder = layers.LayerTypeIPv6
} }

View File

@@ -1,33 +0,0 @@
package dns
import (
"net"
"testing"
nbnet "github.com/netbirdio/netbird/util/net"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -117,15 +117,15 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
continue continue
} }
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dConf.Domain, Domain: dConf.Domain.PunycodeString(),
MatchOnly: dConf.MatchOnly, MatchOnly: dConf.MatchOnly,
}) })
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain) matchDomains = append(matchDomains, dConf.Domain.PunycodeString())
continue continue
} }
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
} }
if config.RouteAll { if config.RouteAll {

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
) )
const ( const (
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []string
domain string domain domain.Domain
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
@@ -62,7 +63,7 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{ return &upstreamResolverBase{

View File

@@ -3,12 +3,14 @@ package dns
import ( import (
"context" "context"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -23,11 +25,11 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string, domain domain.Domain,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
c := &upstreamResolver{ c := &upstreamResolver{

View File

@@ -4,12 +4,13 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
) )
type upstreamResolver struct { type upstreamResolver struct {
@@ -19,11 +20,11 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain domain.Domain,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{ nonIOS := &upstreamResolver{

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@@ -14,23 +15,24 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
) )
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP net.IP lIP netip.Addr
lNet *net.IPNet lNet netip.Prefix
interfaceName string interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
interfaceName string, interfaceName string,
ip net.IP, ip netip.Addr,
net *net.IPNet, net netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain domain.Domain,
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
@@ -58,8 +60,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} }
client.DialTimeout = timeout client.DialTimeout = timeout
upstreamIP := net.ParseIP(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil { if err != nil {
@@ -73,7 +78,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName) index, err := getInterfaceIndex(interfaceName)
if err != nil { if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err) log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
@@ -82,7 +87,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: ip, IP: ip.AsSlice(),
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,

View File

@@ -2,7 +2,7 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {

View File

@@ -121,8 +121,8 @@ type EngineConfig struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool
BlockLANAccess bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
} }
@@ -359,6 +359,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("new wg interface: %w", err) return fmt.Errorf("new wg interface: %w", err)
} }
e.wgInterface = wgIface e.wgInterface = wgIface
e.statusRecorder.SetWgIface(wgIface)
// start flow manager right after interface creation // start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey() publicKey := e.config.WgPrivateKey.PublicKey()
@@ -380,7 +381,6 @@ func (e *Engine) Start() error {
return fmt.Errorf("run rosenpass manager: %w", err) return fmt.Errorf("run rosenpass manager: %w", err)
} }
} }
e.stateManager.Start() e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer() initialRoutes, dnsServer, err := e.newDnsServer()
@@ -431,7 +431,8 @@ func (e *Engine) Start() error {
return fmt.Errorf("up wg interface: %w", err) return fmt.Errorf("up wg interface: %w", err)
} }
if e.firewall != nil { // if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall) e.acl = acl.NewDefaultManager(e.firewall)
} }
@@ -487,11 +488,9 @@ func (e *Engine) createFirewall() error {
} }
func (e *Engine) initFirewall() error { func (e *Engine) initFirewall() error {
if e.firewall.IsServerRouteSupported() { if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { e.close()
e.close() return fmt.Errorf("enable server router: %w", err)
return fmt.Errorf("enable server router: %w", err)
}
} }
if e.config.BlockLANAccess { if e.config.BlockLANAccess {
@@ -525,6 +524,11 @@ func (e *Engine) initFirewall() error {
} }
func (e *Engine) blockLanAccess() { func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
return
}
var merr *multierror.Error var merr *multierror.Error
// TODO: keep this updated // TODO: keep this updated
@@ -796,56 +800,58 @@ func isNil(server nbssh.Server) bool {
} }
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Infof("SSH server is disabled because inbound connections are blocked")
return nil
}
if !e.config.ServerSSHAllowed { if !e.config.ServerSSHAllowed {
log.Warnf("running SSH server is not permitted") log.Info("SSH server is not enabled")
return nil return nil
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
return nil
} }
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
return nil
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
@@ -988,12 +994,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update routes: %v", err)
} }
if e.acl != nil { if e.acl != nil {
@@ -1055,15 +1070,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers()) excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(excludedLazyPeers) e.connMgr.SetExcludeList(excludedLazyPeers)
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage. // Test received (upstream) servers for availability right away instead of upon usage.
@@ -1098,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix.Masked(),
Domains: domain.FromPunycodeList(protoRoute.Domains), Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
@@ -1132,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
return entries return entries
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
@@ -1159,7 +1165,7 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.C
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() { for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
dnsNSGroup := &nbdns.NameServerGroup{ dnsNSGroup := &nbdns.NameServerGroup{
Primary: nsGroup.GetPrimary(), Primary: nsGroup.GetPrimary(),
Domains: nsGroup.GetDomains(), Domains: domain.FromPunycodeList(nsGroup.GetDomains()),
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(), SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
} }
for _, ns := range nsGroup.GetNameServers() { for _, ns := range nsGroup.GetNameServers() {
@@ -1447,6 +1453,7 @@ func (e *Engine) close() {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
} }
e.wgInterface = nil e.wgInterface = nil
e.statusRecorder.SetWgIface(nil)
} }
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
@@ -1671,7 +1678,7 @@ func (e *Engine) RunHealthProbes() bool {
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append( return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)..., relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
) )
} }
@@ -1784,9 +1791,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
} }
// GetWgAddr returns the wireguard address // GetWgAddr returns the wireguard address
func (e *Engine) GetWgAddr() net.IP { func (e *Engine) GetWgAddr() netip.Addr {
if e.wgInterface == nil { if e.wgInterface == nil {
return nil return netip.Addr{}
} }
return e.wgInterface.Address().IP return e.wgInterface.Address().IP
} }
@@ -1796,6 +1803,10 @@ func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
) { ) {
if e.config.DisableServerRoutes {
return
}
if !enabled { if !enabled {
if e.dnsForwardMgr == nil { if e.dnsForwardMgr == nil {
return return
@@ -1851,12 +1862,7 @@ func (e *Engine) Address() (netip.Addr, error) {
return netip.Addr{}, errors.New("wireguard interface not initialized") return netip.Addr{}, errors.New("wireguard interface not initialized")
} }
addr := e.wgInterface.Address() return e.wgInterface.Address().IP, nil
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
}
return ip.Unmap(), nil
} }
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {

View File

@@ -44,6 +44,7 @@ import (
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/management/client" mgmt "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@@ -77,7 +78,7 @@ var (
type MockWGIface struct { type MockWGIface struct {
CreateFunc func() error CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains domain.List) error
IsUserspaceBindFunc func() bool IsUserspaceBindFunc func() bool
NameFunc func() string NameFunc func() string
AddressFunc func() wgaddr.Address AddressFunc func() wgaddr.Address
@@ -99,6 +100,10 @@ type MockWGIface struct {
GetNetFunc func() *netstack.Net GetNetFunc func() *netstack.Net
} }
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented")
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc() return m.GetInterfaceGUIDStringFunc()
} }
@@ -107,7 +112,7 @@ func (m *MockWGIface) Create() error {
return m.CreateFunc() return m.CreateFunc()
} }
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error { func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains domain.List) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains) return m.CreateOnAndroidFunc(routeRange, ip, domains)
} }
@@ -371,11 +376,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/management/domain"
) )
type wgIfaceBase interface { type wgIfaceBase interface {
Create() error Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error CreateOnAndroid(routeRange []string, ip string, domains domain.List) error
IsUserspaceBind() bool IsUserspaceBind() bool
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address
@@ -37,4 +38,5 @@ type wgIfaceBase interface {
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
} }

View File

@@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
// fallback if mark rules are not in place // fallback if mark rules are not in place
wgnet := c.iface.Address().Network wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice()) return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
} }
// mapRxPackets maps packet counts to RX based on flow direction // mapRxPackets maps packet counts to RX based on flow direction
@@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
// fallback if marks are not set // fallback if marks are not set
wgaddr := c.iface.Address().IP wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch { switch {
case wgaddr.Equal(src): case wgaddr == srcIP:
return nftypes.Egress return nftypes.Egress
case wgaddr.Equal(dst): case wgaddr == dstIP:
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(src): case wgnetwork.Contains(srcIP):
// netbird network -> resource network // netbird network -> resource network
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(dst): case wgnetwork.Contains(dstIP):
// resource network -> netbird network // resource network -> netbird network
return nftypes.Egress return nftypes.Egress
} }

View File

@@ -2,7 +2,7 @@ package logger
import ( import (
"context" "context"
"net" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -23,17 +23,16 @@ type Logger struct {
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgIfaceIPNet net.IPNet wgIfaceNet netip.Prefix
dnsCollection atomic.Bool dnsCollection atomic.Bool
exitNodeCollection atomic.Bool exitNodeCollection atomic.Bool
Store types.Store Store types.Store
} }
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
return &Logger{ return &Logger{
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet, wgIfaceNet: wgIfaceIPNet,
Store: store.NewMemoryStore(), Store: store.NewMemoryStore(),
} }
} }
@@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
var isSrcExitNode bool var isSrcExitNode bool
var isDestExitNode bool var isDestExitNode bool
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.SourceIP) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
} }
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.DestIP) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
} }

View File

@@ -1,7 +1,7 @@
package logger_test package logger_test
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -12,7 +12,7 @@ import (
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
logger := logger.New(nil, net.IPNet{}) logger := logger.New(nil, netip.Prefix{})
logger.Enable() logger.Enable()
event := types.EventFields{ event := types.EventFields{

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net" "net/netip"
"runtime" "runtime"
"sync" "sync"
"time" "time"
@@ -34,11 +34,11 @@ type Manager struct {
// NewManager creates a new netflow manager // NewManager creates a new netflow manager
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var ipNet net.IPNet var prefix netip.Prefix
if iface != nil { if iface != nil {
ipNet = *iface.Address().Network prefix = iface.Address().Network
} }
flowLogger := logger.New(statusRecorder, ipNet) flowLogger := logger.New(statusRecorder, prefix)
var ct nftypes.ConnTracker var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {

View File

@@ -1,7 +1,7 @@
package netflow package netflow
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
func TestManager_Update(t *testing.T) { func TestManager_Update(t *testing.T) {
mockIFace := &mockIFaceMapper{ mockIFace := &mockIFaceMapper{
address: wgaddr.Address{ address: wgaddr.Address{
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.1/32"),
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
}, },
isUserspaceBind: true, isUserspaceBind: true,
} }
@@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
func TestManager_Update_TokenPreservation(t *testing.T) { func TestManager_Update_TokenPreservation(t *testing.T) {
mockIFace := &mockIFaceMapper{ mockIFace := &mockIFaceMapper{
address: wgaddr.Address{ address: wgaddr.Address{
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.1/32"),
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
}, },
isUserspaceBind: true, isUserspaceBind: true,
} }

View File

@@ -1,7 +1,9 @@
package peer package peer
import ( import (
"context"
"errors" "errors"
"fmt"
"net/netip" "net/netip"
"slices" "slices"
"sync" "sync"
@@ -31,6 +33,10 @@ type ResolvedDomainInfo struct {
ParentDomain domain.Domain ParentDomain domain.Domain
} }
type WGIfaceStatus interface {
FullStats() (*configurer.Stats, error)
}
type EventListener interface { type EventListener interface {
OnEvent(event *proto.SystemEvent) OnEvent(event *proto.SystemEvent)
} }
@@ -146,11 +152,31 @@ type FullStatus struct {
LazyConnectionEnabled bool LazyConnectionEnabled bool
} }
type StatusChangeSubscription struct {
peerID string
id string
eventsChan chan struct{}
ctx context.Context
}
func newStatusChangeSubscription(ctx context.Context, peerID string) *StatusChangeSubscription {
return &StatusChangeSubscription{
ctx: ctx,
peerID: peerID,
id: uuid.New().String(),
eventsChan: make(chan struct{}, 1),
}
}
func (s *StatusChangeSubscription) Events() chan struct{} {
return s.eventsChan
}
// Status holds a state of peers, signal, management connections and relays // Status holds a state of peers, signal, management connections and relays
type Status struct { type Status struct {
mux sync.Mutex mux sync.Mutex
peers map[string]State peers map[string]State
changeNotify map[string]chan struct{} changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool signalState bool
signalError error signalError error
managementState bool managementState bool
@@ -181,13 +207,14 @@ type Status struct {
ingressGwMgr *ingressgw.Manager ingressGwMgr *ingressgw.Manager
routeIDLookup routeIDLookup routeIDLookup routeIDLookup
wgIface WGIfaceStatus
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
func NewRecorder(mgmAddress string) *Status { func NewRecorder(mgmAddress string) *Status {
return &Status{ return &Status{
peers: make(map[string]State), peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}), changeNotify: make(map[string]map[string]*StatusChangeSubscription),
eventStreams: make(map[string]chan *proto.SystemEvent), eventStreams: make(map[string]chan *proto.SystemEvent),
eventQueue: NewEventQueue(eventQueueSize), eventQueue: NewEventQueue(eventQueueSize),
offlinePeers: make([]State, 0), offlinePeers: make([]State, 0),
@@ -289,11 +316,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
if receivedState.IP != "" { oldState := peerState.ConnStatus
peerState.IP = receivedState.IP
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus { if receivedState.ConnStatus != peerState.ConnStatus {
peerState.ConnStatus = receivedState.ConnStatus peerState.ConnStatus = receivedState.ConnStatus
@@ -309,11 +332,14 @@ func (d *Status) UpdatePeerState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if skipNotification { if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
return nil d.notifyPeerListChanged()
} }
d.notifyPeerListChanged() // when we close the connection we will not notify the router manager
if receivedState.ConnStatus == StatusIdle {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil return nil
} }
@@ -380,11 +406,8 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
if receivedState.IP != "" { oldState := peerState.ConnStatus
peerState.IP = receivedState.IP oldIsRelayed := peerState.Relayed
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
@@ -397,12 +420,13 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if skipNotification { if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
return nil d.notifyPeerListChanged()
} }
d.notifyPeerStateChangeListeners(receivedState.PubKey) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerListChanged() d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil return nil
} }
@@ -415,7 +439,8 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) oldState := peerState.ConnStatus
oldIsRelayed := peerState.Relayed
peerState.ConnStatus = receivedState.ConnStatus peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
@@ -425,12 +450,13 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if skipNotification { if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
return nil d.notifyPeerListChanged()
} }
d.notifyPeerStateChangeListeners(receivedState.PubKey) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerListChanged() d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil return nil
} }
@@ -443,7 +469,8 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) oldState := peerState.ConnStatus
oldIsRelayed := peerState.Relayed
peerState.ConnStatus = receivedState.ConnStatus peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed peerState.Relayed = receivedState.Relayed
@@ -452,12 +479,13 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if skipNotification { if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
return nil d.notifyPeerListChanged()
} }
d.notifyPeerStateChangeListeners(receivedState.PubKey) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerListChanged() d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil return nil
} }
@@ -470,7 +498,8 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) oldState := peerState.ConnStatus
oldIsRelayed := peerState.Relayed
peerState.ConnStatus = receivedState.ConnStatus peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed peerState.Relayed = receivedState.Relayed
@@ -482,12 +511,13 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if skipNotification { if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
return nil d.notifyPeerListChanged()
} }
d.notifyPeerStateChangeListeners(receivedState.PubKey) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerListChanged() d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
return nil return nil
} }
@@ -510,17 +540,12 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGSt
return nil return nil
} }
func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool { func hasStatusOrRelayedChange(oldConnStatus, newConnStatus ConnStatus, oldRelayed, newRelayed bool) bool {
switch { return oldRelayed != newRelayed || hasConnStatusChanged(newConnStatus, oldConnStatus)
case receivedConnStatus == StatusConnecting: }
return true
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting: func hasConnStatusChanged(oldStatus, newStatus ConnStatus) bool {
return true return newStatus != oldStatus
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
return curr.IP != ""
default:
return false
}
} }
// UpdatePeerFQDN update peer's state fqdn only // UpdatePeerFQDN update peer's state fqdn only
@@ -553,19 +578,41 @@ func (d *Status) FinishPeerListModifications() {
d.notifyPeerListChanged() d.notifyPeerListChanged()
} }
// GetPeerStateChangeNotifier returns a change notifier channel for a peer func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
ch, found := d.changeNotify[peer] sub := newStatusChangeSubscription(ctx, peerID)
if found { if _, ok := d.changeNotify[peerID]; !ok {
return ch d.changeNotify[peerID] = make(map[string]*StatusChangeSubscription)
}
d.changeNotify[peerID][sub.id] = sub
return sub
}
func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscription) {
d.mux.Lock()
defer d.mux.Unlock()
if subscription == nil {
return
} }
ch = make(chan struct{}) channels, ok := d.changeNotify[subscription.peerID]
d.changeNotify[peer] = ch if !ok {
return ch return
}
sub, exists := channels[subscription.id]
if !exists {
return
}
delete(channels, subscription.id)
if len(channels) == 0 {
delete(d.changeNotify, sub.peerID)
}
} }
// GetLocalPeerState returns the local peer state // GetLocalPeerState returns the local peer state
@@ -940,13 +987,20 @@ func (d *Status) onConnectionChanged() {
// notifyPeerStateChangeListeners notifies route manager about the change in peer state // notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) { func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID] subs, ok := d.changeNotify[peerID]
if !found { if !ok {
return return
} }
for _, sub := range subs {
close(ch) // block the write because we do not want to miss notification
delete(d.changeNotify, peerID) // must have to be sure we will run the GetPeerState() on separated thread
go func() {
select {
case sub.eventsChan <- struct{}{}:
case <-sub.ctx.Done():
}
}()
}
} }
func (d *Status) notifyPeerListChanged() { func (d *Status) notifyPeerListChanged() {
@@ -1030,6 +1084,23 @@ func (d *Status) GetEventHistory() []*proto.SystemEvent {
return d.eventQueue.GetAll() return d.eventQueue.GetAll()
} }
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
d.mux.Lock()
defer d.mux.Unlock()
d.wgIface = wgInterface
}
func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.Lock()
defer d.mux.Unlock()
if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
}
return d.wgIface.FullStats()
}
type EventQueue struct { type EventQueue struct {
maxSize int maxSize int
events []*proto.SystemEvent events []*proto.SystemEvent

View File

@@ -1,9 +1,11 @@
package peer package peer
import ( import (
"context"
"errors" "errors"
"sync" "sync"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -42,16 +44,16 @@ func TestGetPeer(t *testing.T) {
func TestUpdatePeerState(t *testing.T) { func TestUpdatePeerState(t *testing.T) {
key := "abc" key := "abc"
ip := "10.10.10.10" ip := "10.10.10.10"
fqdn := "peer-a.netbird.local"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
_ = status.AddPeer(key, fqdn, ip)
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex), ConnStatusUpdate: time.Now(),
ConnStatus: StatusConnecting,
} }
status.peers[key] = peerState
peerState.IP = ip
err := status.UpdatePeerState(peerState) err := status.UpdatePeerState(peerState)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
@@ -83,25 +85,27 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
key := "abc" key := "abc"
ip := "10.10.10.10" ip := "10.10.10.10"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
_ = status.AddPeer(key, "abc.netbird", ip)
sub := status.SubscribeToPeerStateChanges(context.Background(), key)
assert.NotNil(t, sub, "channel shouldn't be nil")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex), ConnStatus: StatusConnecting,
Relayed: false,
ConnStatusUpdate: time.Now(),
} }
status.peers[key] = peerState
ch := status.GetPeerStateChangeNotifier(key)
assert.NotNil(t, ch, "channel shouldn't be nil")
peerState.IP = ip
err := status.UpdatePeerRelayedStateToDisconnected(peerState) err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
select { select {
case <-ch: case <-sub.eventsChan:
default: case <-timeoutCtx.Done():
t.Errorf("channel wasn't closed after update") t.Errorf("timed out waiting for event")
} }
} }

View File

@@ -170,7 +170,7 @@ func ProbeAll(
var wg sync.WaitGroup var wg sync.WaitGroup
for i, uri := range relays { for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second) ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
defer cancel() defer cancel()
wg.Add(1) wg.Add(1)

View File

@@ -1,4 +1,4 @@
package routemanager package client
import ( import (
"context" "context"
@@ -7,10 +7,8 @@ import (
"runtime" "runtime"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
@@ -36,6 +34,7 @@ const (
reasonRouteUpdate reasonRouteUpdate
reasonPeerUpdate reasonPeerUpdate
reasonShutdown reasonShutdown
reasonHA
) )
type routerPeerStatus struct { type routerPeerStatus struct {
@@ -44,9 +43,9 @@ type routerPeerStatus struct {
latency time.Duration latency time.Duration
} }
type routesUpdate struct { type RoutesUpdate struct {
updateSerial uint64 UpdateSerial uint64
routes []*route.Route Routes []*route.Route
} }
// RouteHandler defines the interface for handling routes // RouteHandler defines the interface for handling routes
@@ -58,64 +57,54 @@ type RouteHandler interface {
RemoveAllowedIPs() error RemoveAllowedIPs() error
} }
type clientNetwork struct { type WatcherConfig struct {
Context context.Context
DNSRouteInterval time.Duration
WGInterface iface.WGIface
StatusRecorder *peer.Status
Route *route.Route
Handler RouteHandler
}
// Watcher watches route and peer changes and updates allowed IPs accordingly.
// Once stopped, it cannot be reused.
type Watcher struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.WGIface wgInterface iface.WGIface
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan RoutesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{} routePeersNotifiers map[string]chan struct{} // map of peer key to channel for peer state changes
currentChosen *route.Route currentChosen *route.Route
handler RouteHandler handler RouteHandler
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher( func NewWatcher(config WatcherConfig) *Watcher {
ctx context.Context, ctx, cancel := context.WithCancel(config.Context)
dnsRouteInterval time.Duration,
wgInterface iface.WGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &Watcher{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
statusRecorder: statusRecorder, statusRecorder: config.StatusRecorder,
wgInterface: wgInterface, wgInterface: config.WGInterface,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan RoutesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute( handler: config.Handler,
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouteInterval,
statusRecorder,
wgInterface,
dnsServer,
peerStore,
useNewDNSRoute,
),
} }
return client return client
} }
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[route.ID]routerPeerStatus) routePeerStatuses := make(map[route.ID]routerPeerStatus)
for _, r := range c.routes { for _, r := range w.routes {
peerStatus, err := c.statusRecorder.GetPeer(r.Peer) peerStatus, err := w.statusRecorder.GetPeer(r.Peer)
if err != nil { if err != nil {
log.Debugf("couldn't fetch peer state: %v", err) log.Debugf("couldn't fetch peer state %v: %v", r.Peer, err)
continue continue
} }
routePeerStatuses[r.ID] = routerPeerStatus{ routePeerStatuses[r.ID] = routerPeerStatus{
@@ -128,7 +117,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
} }
// getBestRouteFromStatuses determines the most optimal route from the available routes // getBestRouteFromStatuses determines the most optimal route from the available routes
// within a clientNetwork, taking into account peer connection status, route metrics, and // within a Watcher, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections. // preference for non-relayed and direct connections.
// //
// It follows these prioritization rules: // It follows these prioritization rules:
@@ -140,17 +129,17 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Stability: In case of equal scores, the currently active route (if any) is maintained. // * Stability: In case of equal scores, the currently active route (if any) is maintained.
// //
// It returns the ID of the selected optimal route. // It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
chosen := route.ID("") var chosen route.ID
chosenScore := float64(0) chosenScore := float64(0)
currScore := float64(0) currScore := float64(0)
currID := route.ID("") var currID route.ID
if c.currentChosen != nil { if w.currentChosen != nil {
currID = c.currentChosen.ID currID = w.currentChosen.ID
} }
for _, r := range c.routes { for _, r := range w.routes {
tempScore := float64(0) tempScore := float64(0)
peerStatus, found := routePeerStatuses[r.ID] peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected { if !found || !peerStatus.connected {
@@ -167,7 +156,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
if peerStatus.latency != 0 { if peerStatus.latency != 0 {
latency = peerStatus.latency latency = peerStatus.latency
} else { } else {
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler) log.Tracef("peer %s has 0 latency, range %s", r.Peer, w.handler)
} }
// avoid negative tempScore on the higher latency calculation // avoid negative tempScore on the higher latency calculation
@@ -197,149 +186,145 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
} }
} }
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore) chosenID := chosen
if chosen == "" {
chosenID = "<none>"
}
currentID := currID
if currID == "" {
currentID = "<none>"
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosenID, chosenScore, currentID, currScore)
switch { switch {
case chosen == "": case chosen == "":
var peers []string var peers []string
for _, r := range c.routes { for _, r := range w.routes {
peers = append(peers, r.Peer) peers = append(peers, r.Peer)
} }
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers) log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", w.handler, peers)
case chosen != currID: case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes // we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore { if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore) log.Debugf("keeping current routing peer %s for [%v]: the score difference with latency is less than 0.01(10ms): current: %f, new: %f",
w.currentChosen.Peer, w.handler, currScore, chosenScore)
return currID return currID
} }
var p string var p string
if rt := c.routes[chosen]; rt != nil { if rt := w.routes[chosen]; rt != nil {
p = rt.Peer p = rt.Peer
} }
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler) log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, w.handler)
} }
return chosen return chosen
} }
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) { func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
subscription := w.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey)
defer w.statusRecorder.UnsubscribePeerStateChanges(subscription)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-closer: case <-closer:
return return
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey): case <-subscription.Events():
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil || state.ConnStatus == peer.StatusConnecting {
continue
}
peerStateUpdate <- struct{}{} peerStateUpdate <- struct{}{}
log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus) log.Debugf("triggered route state update for Peer: %s", peerKey)
} }
} }
} }
func (c *clientNetwork) startPeersStatusChangeWatcher() { func (w *Watcher) startNewPeerStatusWatchers() {
for _, r := range c.routes { for _, r := range w.routes {
_, found := c.routePeersNotifiers[r.Peer] if _, found := w.routePeersNotifiers[r.Peer]; found {
if found {
continue continue
} }
closerChan := make(chan struct{}) closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan w.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan) go w.watchPeerStatusChanges(w.ctx, r.Peer, w.peerStateUpdate, closerChan)
} }
} }
func (c *clientNetwork) removeRouteFromWireGuardPeer() error { // addAllowedIPs adds the allowed IPs for the current chosen route to the handler.
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil { func (w *Watcher) addAllowedIPs(route *route.Route) error {
if err := w.handler.AddAllowedIPs(route.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", route.Peer, err)
}
if err := w.statusRecorder.AddPeerStateRoute(route.Peer, w.handler.String(), route.GetResourceID()); err != nil {
log.Warnf("Failed to update peer state: %v", err) log.Warnf("Failed to update peer state: %v", err)
} }
if err := c.handler.RemoveAllowedIPs(); err != nil { w.connectEvent(route)
return fmt.Errorf("remove allowed IPs: %w", err)
}
return nil return nil
} }
func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error { func (w *Watcher) removeAllowedIPs(route *route.Route, rsn reason) error {
if c.currentChosen == nil { if err := w.statusRecorder.RemovePeerStateRoute(route.Peer, w.handler.String()); err != nil {
return nil log.Warnf("Failed to update peer state: %v", err)
} }
var merr *multierror.Error if err := w.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
} }
c.disconnectEvent(rsn) w.disconnectEvent(route, rsn)
return nberrors.FormatErrorOrNil(merr) return nil
} }
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error { func (w *Watcher) recalculateRoutes(rsn reason) error {
routerPeerStatuses := c.getRouterPeerStatuses() routerPeerStatuses := w.getRouterPeerStatuses()
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses) newChosenID := w.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system // If no route is chosen, remove the route from the peer
if newChosenID == "" { if newChosenID == "" {
if err := c.removeRouteFromPeerAndSystem(rsn); err != nil { if w.currentChosen == nil {
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err) return nil
} }
c.currentChosen = nil if err := w.removeAllowedIPs(w.currentChosen, rsn); err != nil {
return fmt.Errorf("remove obsolete: %w", err)
}
w.currentChosen = nil
return nil return nil
} }
// If the chosen route is the same as the current route, do nothing // If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID && if w.currentChosen != nil && w.currentChosen.ID == newChosenID &&
c.currentChosen.Equal(c.routes[newChosenID]) { w.currentChosen.Equal(w.routes[newChosenID]) {
return nil return nil
} }
var isNew bool // If the chosen route was assigned to a different peer, remove the allowed IPs first
if c.currentChosen == nil { if isNew := w.currentChosen == nil; !isNew {
// If they were not previously assigned to another peer, add routes to the system first if err := w.removeAllowedIPs(w.currentChosen, reasonHA); err != nil {
if err := c.handler.AddRoute(c.ctx); err != nil { return fmt.Errorf("remove old: %w", err)
return fmt.Errorf("add route: %w", err)
}
isNew = true
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
} }
} }
c.currentChosen = c.routes[newChosenID] newChosenRoute := w.routes[newChosenID]
if err := w.addAllowedIPs(newChosenRoute); err != nil {
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil { return fmt.Errorf("add new: %w", err)
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
} }
if isNew { w.currentChosen = newChosenRoute
c.connectEvent()
}
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
return nil return nil
} }
func (c *clientNetwork) connectEvent() { func (w *Watcher) connectEvent(route *route.Route) {
var defaultRoute bool var defaultRoute bool
for _, r := range c.routes { for _, r := range w.routes {
if r.Network.Bits() == 0 { if r.Network.Bits() == 0 {
defaultRoute = true defaultRoute = true
break break
@@ -351,13 +336,13 @@ func (c *clientNetwork) connectEvent() {
} }
meta := map[string]string{ meta := map[string]string{
"network": c.handler.String(), "network": w.handler.String(),
} }
if c.currentChosen != nil { if route != nil {
meta["id"] = string(c.currentChosen.NetID) meta["id"] = string(route.NetID)
meta["peer"] = c.currentChosen.Peer meta["peer"] = route.Peer
} }
c.statusRecorder.PublishEvent( w.statusRecorder.PublishEvent(
proto.SystemEvent_INFO, proto.SystemEvent_INFO,
proto.SystemEvent_NETWORK, proto.SystemEvent_NETWORK,
"Default route added", "Default route added",
@@ -366,9 +351,9 @@ func (c *clientNetwork) connectEvent() {
) )
} }
func (c *clientNetwork) disconnectEvent(rsn reason) { func (w *Watcher) disconnectEvent(route *route.Route, rsn reason) {
var defaultRoute bool var defaultRoute bool
for _, r := range c.routes { for _, r := range w.routes {
if r.Network.Bits() == 0 { if r.Network.Bits() == 0 {
defaultRoute = true defaultRoute = true
break break
@@ -384,11 +369,11 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
var userMessage string var userMessage string
meta := make(map[string]string) meta := make(map[string]string)
if c.currentChosen != nil { if route != nil {
meta["id"] = string(c.currentChosen.NetID) meta["id"] = string(route.NetID)
meta["peer"] = c.currentChosen.Peer meta["peer"] = route.Peer
} }
meta["network"] = c.handler.String() meta["network"] = w.handler.String()
switch rsn { switch rsn {
case reasonShutdown: case reasonShutdown:
severity = proto.SystemEvent_INFO severity = proto.SystemEvent_INFO
@@ -401,13 +386,17 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
severity = proto.SystemEvent_WARNING severity = proto.SystemEvent_WARNING
message = "Default route disconnected due to peer unreachability" message = "Default route disconnected due to peer unreachability"
userMessage = "Exit node connection lost. Your internet access might be affected." userMessage = "Exit node connection lost. Your internet access might be affected."
case reasonHA:
severity = proto.SystemEvent_INFO
message = "Default route disconnected due to high availability change"
userMessage = "Exit node disconnected due to high availability change."
default: default:
severity = proto.SystemEvent_ERROR severity = proto.SystemEvent_ERROR
message = "Default route disconnected for unknown reasons" message = "Default route disconnected for unknown reasons"
userMessage = "Exit node disconnected for unknown reasons." userMessage = "Exit node disconnected for unknown reasons."
} }
c.statusRecorder.PublishEvent( w.statusRecorder.PublishEvent(
severity, severity,
proto.SystemEvent_NETWORK, proto.SystemEvent_NETWORK,
message, message,
@@ -416,86 +405,101 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
) )
} }
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { func (w *Watcher) SendUpdate(update RoutesUpdate) {
go func() { go func() {
c.routeUpdate <- update select {
case w.routeUpdate <- update:
case <-w.ctx.Done():
}
}() }()
} }
func (c *clientNetwork) handleUpdate(update routesUpdate) bool { func (w *Watcher) classifyUpdate(update RoutesUpdate) bool {
isUpdateMapDifferent := false isUpdateMapDifferent := false
updateMap := make(map[route.ID]*route.Route) updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes { for _, r := range update.Routes {
updateMap[r.ID] = r updateMap[r.ID] = r
} }
if len(c.routes) != len(updateMap) { if len(w.routes) != len(updateMap) {
isUpdateMapDifferent = true isUpdateMapDifferent = true
} }
for id, r := range c.routes { for id, r := range w.routes {
_, found := updateMap[id] _, found := updateMap[id]
if !found { if !found {
close(c.routePeersNotifiers[r.Peer]) close(w.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer) delete(w.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true isUpdateMapDifferent = true
continue continue
} }
if !reflect.DeepEqual(c.routes[id], updateMap[id]) { if !reflect.DeepEqual(w.routes[id], updateMap[id]) {
isUpdateMapDifferent = true isUpdateMapDifferent = true
} }
} }
c.routes = updateMap w.routes = updateMap
return isUpdateMapDifferent return isUpdateMapDifferent
} }
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events. // Start is the main point of reacting on client network routing events.
// All the processing related to the client network should be done here. Thread-safe. // All the processing related to the client network should be done here. Thread-safe.
func (c *clientNetwork) peersStateAndUpdateWatcher() { func (w *Watcher) Start() {
for { for {
select { select {
case <-c.ctx.Done(): case <-w.ctx.Done():
log.Debugf("Stopping watcher for network [%v]", c.handler)
if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
}
return return
case <-c.peerStateUpdate: case <-w.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate) if err := w.recalculateRoutes(reasonPeerUpdate); err != nil {
if err != nil { log.Errorf("Failed to recalculate routes for network [%v]: %v", w.handler, err)
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
} }
case update := <-c.routeUpdate: case update := <-w.routeUpdate:
if update.updateSerial < c.updateSerial { if update.UpdateSerial < w.updateSerial {
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial) log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", w.updateSerial, update.UpdateSerial)
continue continue
} }
log.Debugf("Received a new client network route update for [%v]", c.handler) w.handleRouteUpdate(update)
// hash update somehow
isTrueRouteUpdate := c.handleUpdate(update)
c.updateSerial = update.updateSerial
if isTrueRouteUpdate {
log.Debug("Client network update contains different routes, recalculating routes")
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate)
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
} else {
log.Debug("Route update is not different, skipping route recalculation")
}
c.startPeersStatusChangeWatcher()
} }
} }
} }
func handlerFromRoute( func (w *Watcher) handleRouteUpdate(update RoutesUpdate) {
log.Debugf("Received a new client network route update for [%v]", w.handler)
// hash update somehow
isTrueRouteUpdate := w.classifyUpdate(update)
w.updateSerial = update.UpdateSerial
if isTrueRouteUpdate {
log.Debugf("client network update %v for [%v] contains different routes, recalculating routes", update.UpdateSerial, w.handler)
if err := w.recalculateRoutes(reasonRouteUpdate); err != nil {
log.Errorf("failed to recalculate routes for network [%v]: %v", w.handler, err)
}
} else {
log.Debugf("route update %v for [%v] is not different, skipping route recalculation", update.UpdateSerial, w.handler)
}
w.startNewPeerStatusWatchers()
}
// Stop stops the watcher and cleans up resources.
func (w *Watcher) Stop() {
log.Debugf("Stopping watcher for network [%v]", w.handler)
w.cancel()
if w.currentChosen == nil {
return
}
if err := w.removeAllowedIPs(w.currentChosen, reasonShutdown); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", w.handler, err)
}
}
func HandlerFromRoute(
rt *route.Route, rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter, routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,

View File

@@ -1,4 +1,4 @@
package routemanager package client
import ( import (
"fmt" "fmt"
@@ -395,7 +395,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
} }
// create new clientNetwork // create new clientNetwork
client := &clientNetwork{ client := &Watcher{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil), handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes, routes: tc.existingRoutes,
currentChosen: currentRoute, currentChosen: currentRoute,

View File

@@ -229,15 +229,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
} }
if len(r.Answer) > 0 && len(r.Question) > 0 { if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := "" var origPattern domain.Domain
if writer, ok := w.(*nbdns.ResponseWriterChain); ok { if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
origPattern = writer.GetOrigPattern() origPattern = writer.GetOrigPattern()
} }
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name)) resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
// already punycode via RegisterHandler() originalDomain := origPattern
originalDomain := domain.Domain(origPattern)
if originalDomain == "" { if originalDomain == "" {
originalDomain = resolvedDomain originalDomain = resolvedDomain
} }
@@ -264,7 +263,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
continue continue
} }
prefix := netip.PrefixFrom(ip, ip.BitLen()) prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
newPrefixes = append(newPrefixes, prefix) newPrefixes = append(newPrefixes, prefix)
} }

View File

@@ -11,9 +11,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
@@ -21,9 +23,11 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/server"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
@@ -68,9 +72,9 @@ type DefaultManager struct {
ctx context.Context ctx context.Context
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork clientNetworks map[route.HAUniqueID]*client.Watcher
routeSelector *routeselector.RouteSelector routeSelector *routeselector.RouteSelector
serverRouter *serverRouter serverRouter *server.Router
sysOps *systemops.SysOps sysOps *systemops.SysOps
statusRecorder *peer.Status statusRecorder *peer.Status
relayMgr *relayClient.Manager relayMgr *relayClient.Manager
@@ -88,6 +92,7 @@ type DefaultManager struct {
useNewDNSRoute bool useNewDNSRoute bool
disableClientRoutes bool disableClientRoutes bool
disableServerRoutes bool disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler
} }
func NewManager(config ManagerConfig) *DefaultManager { func NewManager(config ManagerConfig) *DefaultManager {
@@ -99,7 +104,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
dnsRouteInterval: config.DNSRouteInterval, dnsRouteInterval: config.DNSRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*client.Watcher),
relayMgr: config.RelayManager, relayMgr: config.RelayManager,
sysOps: sysOps, sysOps: sysOps,
statusRecorder: config.StatusRecorder, statusRecorder: config.StatusRecorder,
@@ -111,6 +116,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
peerStore: config.PeerStore, peerStore: config.PeerStore,
disableClientRoutes: config.DisableClientRoutes, disableClientRoutes: config.DisableClientRoutes,
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
} }
useNoop := netstack.IsEnabled() || config.DisableClientRoutes useNoop := netstack.IsEnabled() || config.DisableClientRoutes
@@ -226,7 +232,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
} }
var err error var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil { if err != nil {
return err return err
} }
@@ -237,7 +243,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop() m.stop()
if m.serverRouter != nil { if m.serverRouter != nil {
m.serverRouter.cleanUp() m.serverRouter.CleanUp()
} }
if m.routeRefCounter != nil { if m.routeRefCounter != nil {
@@ -265,6 +271,54 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
toAdd := make(map[route.HAUniqueID]*route.Route)
toRemove := make(map[route.HAUniqueID]client.RouteHandler)
for id, routes := range newRoutes {
if len(routes) > 0 {
toAdd[id] = routes[0]
}
}
for id, activeHandler := range m.activeRoutes {
if _, exists := toAdd[id]; exists {
delete(toAdd, id)
} else {
toRemove[id] = activeHandler
}
}
var merr *multierror.Error
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
}
delete(m.activeRoutes, id)
}
for id, route := range toAdd {
handler := client.HandlerFromRoute(
route,
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsRouteInterval,
m.statusRecorder,
m.wgInterface,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue
}
m.activeRoutes[id] = handler
}
return nberrors.FormatErrorOrNil(merr)
}
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@@ -279,22 +333,28 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
var merr *multierror.Error
if !m.disableClientRoutes { if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
}
m.updateClientNetworks(updateSerial, filteredClientRoutes) m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes)
} }
m.clientRoutes = newClientRoutesIDMap m.clientRoutes = newClientRoutesIDMap
if m.serverRouter == nil { if m.serverRouter == nil {
return nil return nberrors.FormatErrorOrNil(merr)
} }
if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
return fmt.Errorf("update routes: %w", err) merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err))
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
// SetRouteChangeListener set RouteListener for route change Notifier // SetRouteChangeListener set RouteListener for route change Notifier
@@ -341,6 +401,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.notifier.OnNewRoutes(networks) m.notifier.OnNewRoutes(networks)
if err := m.updateSystemRoutes(networks); err != nil {
log.Errorf("failed to update system routes during selection: %v", err)
}
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)
for id, routes := range networks { for id, routes := range networks {
@@ -349,21 +413,24 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue continue
} }
clientNetworkWatcher := newClientNetworkWatcher( handler := m.activeRoutes[id]
m.ctx, if handler == nil {
m.dnsRouteInterval, log.Warnf("no active handler found for route %s", id)
m.wgInterface, continue
m.statusRecorder, }
routes[0],
m.routeRefCounter, config := client.WatcherConfig{
m.allowedIPsRefCounter, Context: m.ctx,
m.dnsServer, DNSRouteInterval: m.dnsRouteInterval,
m.peerStore, WGInterface: m.wgInterface,
m.useNewDNSRoute, StatusRecorder: m.statusRecorder,
) Route: routes[0],
Handler: handler,
}
clientNetworkWatcher := client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.Start()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
} }
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil { if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
@@ -375,8 +442,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) { func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok { if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id) client.Stop()
client.cancel()
delete(m.clientNetworks, id) delete(m.clientNetworks, id)
} }
} }
@@ -389,26 +455,29 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher( handler := m.activeRoutes[id]
m.ctx, if handler == nil {
m.dnsRouteInterval, log.Errorf("No active handler found for route %s", id)
m.wgInterface, continue
m.statusRecorder, }
routes[0],
m.routeRefCounter, config := client.WatcherConfig{
m.allowedIPsRefCounter, Context: m.ctx,
m.dnsServer, DNSRouteInterval: m.dnsRouteInterval,
m.peerStore, WGInterface: m.wgInterface,
m.useNewDNSRoute, StatusRecorder: m.statusRecorder,
) Route: routes[0],
Handler: handler,
}
clientNetworkWatcher = client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.Start()
} }
update := routesUpdate{ update := client.RoutesUpdate{
updateSerial: updateSerial, UpdateSerial: updateSerial,
routes: routes, Routes: routes,
} }
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) clientNetworkWatcher.SendUpdate(update)
} }
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime"
"testing" "testing"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
@@ -45,7 +44,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -72,7 +71,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.252.250/30"), Network: netip.MustParsePrefix("100.64.252.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -100,7 +99,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"), Network: netip.MustParsePrefix("100.64.30.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -128,7 +127,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"), Network: netip.MustParsePrefix("100.64.30.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -212,7 +211,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -234,7 +233,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -251,7 +250,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -273,7 +272,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -283,7 +282,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "b", ID: "b",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey2, Peer: remotePeerKey2,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -300,7 +299,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -328,7 +327,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -357,7 +356,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "l1", ID: "l1",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -377,7 +376,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "r1", ID: "r1",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@@ -441,11 +440,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
} }
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected expectedWatchers := testCase.clientNetworkWatchersExpected
@@ -454,8 +453,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil { if routeManager.serverRouter != nil {
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") require.Equal(t, testCase.serverRoutesExpected, routeManager.serverRouter.RoutesCount(), "server networks size should match")
} }
}) })
} }

View File

@@ -1,4 +1,4 @@
package routemanager package server
import ( import (
"context" "context"
@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type serverRouter struct { type Router struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context ctx context.Context
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
@@ -23,8 +23,8 @@ type serverRouter struct {
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) {
return &serverRouter{ return &Router{
ctx: ctx, ctx: ctx,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),
firewall: firewall, firewall: firewall,
@@ -33,104 +33,110 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi
}, nil }, nil
} }
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error { func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
m.mux.Lock() r.mux.Lock()
defer m.mux.Unlock() defer r.mux.Unlock()
serverRoutesToRemove := make([]route.ID, 0) serverRoutesToRemove := make([]route.ID, 0)
for routeID := range m.routes { for routeID := range r.routes {
update, found := routesMap[routeID] update, found := routesMap[routeID]
if !found || !update.Equal(m.routes[routeID]) { if !found || !update.Equal(r.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID) serverRoutesToRemove = append(serverRoutesToRemove, routeID)
} }
} }
for _, routeID := range serverRoutesToRemove { for _, routeID := range serverRoutesToRemove {
oldRoute := m.routes[routeID] oldRoute := r.routes[routeID]
err := m.removeFromServerNetwork(oldRoute) err := r.removeFromServerNetwork(oldRoute)
if err != nil { if err != nil {
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err) oldRoute.ID, oldRoute.Network, err)
} }
delete(m.routes, routeID) delete(r.routes, routeID)
} }
// If routing is to be disabled, do it after routes have been removed // If routing is to be disabled, do it after routes have been removed
// If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled // If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled
if len(routesMap) > 0 { if len(routesMap) > 0 {
if err := m.firewall.EnableRouting(); err != nil { if err := r.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err) return fmt.Errorf("enable routing: %w", err)
} }
} else { } else {
if err := m.firewall.DisableRouting(); err != nil { if err := r.firewall.DisableRouting(); err != nil {
return fmt.Errorf("disable routing: %w", err) return fmt.Errorf("disable routing: %w", err)
} }
} }
for id, newRoute := range routesMap { for id, newRoute := range routesMap {
_, found := m.routes[id] _, found := r.routes[id]
if found { if found {
continue continue
} }
err := m.addToServerNetwork(newRoute, useNewDNSRoute) err := r.addToServerNetwork(newRoute, useNewDNSRoute)
if err != nil { if err != nil {
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
continue continue
} }
m.routes[id] = newRoute r.routes[id] = newRoute
} }
return nil return nil
} }
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { func (r *Router) removeFromServerNetwork(route *route.Route) error {
if m.ctx.Err() != nil { if r.ctx.Err() != nil {
log.Infof("Not removing from server network because context is done") log.Infof("Not removing from server network because context is done")
return m.ctx.Err() return r.ctx.Err()
} }
routerPair := routeToRouterPair(route, false) routerPair := routeToRouterPair(route, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil { if err := r.firewall.RemoveNatRule(routerPair); err != nil {
return fmt.Errorf("remove routing rules: %w", err) return fmt.Errorf("remove routing rules: %w", err)
} }
delete(m.routes, route.ID) delete(r.routes, route.ID)
m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString()) r.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
return nil return nil
} }
func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error { func (r *Router) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
if m.ctx.Err() != nil { if r.ctx.Err() != nil {
log.Infof("Not adding to server network because context is done") log.Infof("Not adding to server network because context is done")
return m.ctx.Err() return r.ctx.Err()
} }
routerPair := routeToRouterPair(route, useNewDNSRoute) routerPair := routeToRouterPair(route, useNewDNSRoute)
if err := m.firewall.AddNatRule(routerPair); err != nil { if err := r.firewall.AddNatRule(routerPair); err != nil {
return fmt.Errorf("insert routing rules: %w", err) return fmt.Errorf("insert routing rules: %w", err)
} }
m.routes[route.ID] = route r.routes[route.ID] = route
m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID()) r.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
return nil return nil
} }
func (m *serverRouter) cleanUp() { func (r *Router) CleanUp() {
m.mux.Lock() r.mux.Lock()
defer m.mux.Unlock() defer r.mux.Unlock()
for _, r := range m.routes { for _, route := range r.routes {
routerPair := routeToRouterPair(r, false) routerPair := routeToRouterPair(route, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil { if err := r.firewall.RemoveNatRule(routerPair); err != nil {
log.Errorf("Failed to remove cleanup route: %v", err) log.Errorf("Failed to remove cleanup route: %v", err)
} }
} }
m.statusRecorder.CleanLocalPeerStateRoutes() r.statusRecorder.CleanLocalPeerStateRoutes()
}
func (r *Router) RoutesCount() int {
r.mux.Lock()
defer r.mux.Unlock()
return len(r.routes)
} }
func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair { func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {

View File

@@ -24,19 +24,22 @@ func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allo
} }
} }
// Route route methods
func (r *Route) String() string { func (r *Route) String() string {
return r.route.Network.String() return r.route.Network.String()
} }
func (r *Route) AddRoute(context.Context) error { func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) if _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}); err != nil {
return err return err
}
return nil
} }
func (r *Route) RemoveRoute() error { func (r *Route) RemoveRoute() error {
_, err := r.routeRefCounter.Decrement(r.route.Network) if _, err := r.routeRefCounter.Decrement(r.route.Network); err != nil {
return err return err
}
return nil
} }
func (r *Route) AddAllowedIPs(peerKey string) error { func (r *Route) AddAllowedIPs(peerKey string) error {
@@ -52,6 +55,8 @@ func (r *Route) AddAllowedIPs(peerKey string) error {
} }
func (r *Route) RemoveAllowedIPs() error { func (r *Route) RemoveAllowedIPs() error {
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network) if _, err := r.allowedIPsRefcounter.Decrement(r.route.Network); err != nil {
return err return err
}
return nil
} }

View File

@@ -13,7 +13,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
const ( const (
@@ -22,8 +22,13 @@ const (
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark" srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
) )
type iface interface {
Address() wgaddr.Address
Name() string
}
// Setup configures sysctl settings for RP filtering and source validation. // Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface iface.WGIface) (map[string]int, error) { func Setup(wgIface iface) (map[string]int, error) {
keys := map[string]int{} keys := map[string]int{}
var result *multierror.Error var result *multierror.Error

View File

@@ -6,9 +6,10 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
) )
type Nexthop struct { type Nexthop struct {
@@ -30,11 +31,16 @@ func (n Nexthop) String() string {
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name) return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
} }
type wgIface interface {
Address() wgaddr.Address
Name() string
}
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct { type SysOps struct {
refCounter *ExclusionCounter refCounter *ExclusionCounter
wgInterface iface.WGIface wgInterface wgIface
// prefixes is tracking all the current added prefixes im memory // prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update) // (this is used in iOS as all route updates require a full table update)
//nolint //nolint
@@ -45,9 +51,27 @@ type SysOps struct {
notifier *notifier.Notifier notifier *notifier.Notifier
} }
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier, notifier: notifier,
} }
} }
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr()
switch {
case
!addr.IsValid(),
addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsMulticast(),
addr.IsUnspecified() && prefix.Bits() != 0,
r.wgInterface.Address().Network.Contains(addr):
return vars.ErrRouteNotAllowed
}
return nil
}

View File

@@ -8,6 +8,8 @@ import (
"net/netip" "net/netip"
"os/exec" "os/exec"
"regexp" "regexp"
"runtime"
"strings"
"sync" "sync"
"testing" "testing"
@@ -33,7 +35,12 @@ func init() {
func TestConcurrentRoutes(t *testing.T) { func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0") baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil) r := NewSysOps(nil, nil)
@@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err) t.Errorf("Failed to add route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)
@@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err) t.Errorf("Failed to remove route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)
@@ -119,18 +126,39 @@ func TestBits(t *testing.T) {
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper() t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() if runtime.GOOS == "darwin" {
require.NoError(t, err, "Failed to create loopback alias") err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() { t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove loopback alias") assert.NoError(t, err, "Failed to remove route from table")
}) })
return "lo0" return intf
} }
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper() t.Helper()
var originalNexthop net.IP var originalNexthop net.IP
@@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) {
return net.ParseIP(matches[1]), nil return net.ParseIP(matches[1]), nil
} }
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) { func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper() t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
} }

View File

@@ -17,7 +17,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil return nil
} }
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values. // If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) { func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr() if err := r.validateRoute(prefix); err != nil {
switch { return Nexthop{}, err
case addr.IsLoopback(), }
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
addr := prefix.Addr()
if addr.IsUnspecified() {
return Nexthop{}, vars.ErrRouteNotAllowed return Nexthop{}, vars.ErrRouteNotAllowed
} }
@@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface
Intf: nexthop.Intf, Intf: nexthop.Intf,
} }
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) vpnAddr := vpnIntf.Address().IP
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
@@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
return nil return nil
} }
return r.addNonExistingRoute(prefix, intf) return r.addToRouteTable(prefix, nextHop)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
} }
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, // genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
@@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil { if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil { if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound return Nexthop{Intf: intf}, nil
} }
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc) log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
@@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
return addr.Unmap(), nil return addr.Unmap(), nil
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting() localRoutes, err := hasSeparateRouting()

View File

@@ -3,23 +3,25 @@
package systemops package systemops
import ( import (
"bytes"
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os" "os/exec"
"runtime" "runtime"
"strconv"
"strings" "strings"
"syscall"
"testing" "testing"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
) )
type dialer interface { type dialer interface {
@@ -27,105 +29,370 @@ type dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
func TestAddRemoveRoutes(t *testing.T) { func TestAddVPNRoute(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
prefix netip.Prefix prefix netip.Prefix
shouldRouteToWireguard bool expectError bool
shouldBeRemoved bool
}{ }{
{ {
name: "Should Add And Remove Route 100.66.120.0/24", name: "IPv4 - Private network route",
prefix: netip.MustParsePrefix("100.66.120.0/24"), prefix: netip.MustParsePrefix("10.10.100.0/24"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
}, },
{ {
name: "Should Not Add Or Remove Route 127.0.0.1/32", name: "IPv4 Single host",
prefix: netip.MustParsePrefix("127.0.0.1/32"), prefix: netip.MustParsePrefix("10.111.111.111/32"),
shouldRouteToWireguard: false, },
shouldBeRemoved: false, {
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Default route",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
},
{
name: "IPv6 Default route",
prefix: netip.MustParsePrefix("::/0"),
},
// IPv4 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local multicast",
prefix: netip.MustParsePrefix("224.0.0.251/32"),
expectError: true,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
},
{
name: "IPv4 Unspecified with prefix",
prefix: netip.MustParsePrefix("0.0.0.0/32"),
expectError: true,
},
// IPv6 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local multicast",
prefix: netip.MustParsePrefix("ff02::1/128"),
expectError: true,
},
{
name: "IPv6 Interface-local multicast",
prefix: netip.MustParsePrefix("ff01::1/128"),
expectError: true,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
},
{
name: "IPv6 Unspecified with prefix",
prefix: netip.MustParsePrefix("::/128"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network subnet",
prefix: netip.MustParsePrefix("100.65.75.0/32"),
expectError: true,
}, },
} }
for n, testCase := range testCases { for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey() wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
_, _, err = r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
}) })
index, err := net.InterfaceByName(wgInterface.Name()) intf, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err)
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// add the route
err = r.AddVPNRoute(testCase.prefix, intf) err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.expectError {
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
}
if testCase.shouldRouteToWireguard { // validate it's pointing to the WireGuard interface
assertWGOutInterface(t, testCase.prefix, wgInterface, false) require.NoError(t, err)
nextHop := getNextHop(t, testCase.prefix.Addr())
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
// remove route again
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err)
// validate it's gone
nextHop, err = GetNextHop(testCase.prefix.Addr())
require.True(t,
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
"err: %v, next hop: %v", err, nextHop)
})
}
}
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
t.Helper()
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
nextHop, err := GetNextHop(addr)
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
// present in the route table.
t.Skip("Skipping windows test")
}
require.NoError(t, err)
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
return nextHop
}
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
if addr.IsUnspecified() {
// On macOS, querying 0.0.0.0 returns the wrong interface
if addr.Is4() {
addr = netip.MustParseAddr("1.2.3.4")
} else {
addr = netip.MustParseAddr("2001:db8::1")
}
}
cmd := exec.Command("route", "-n", "get", addr.String())
if addr.Is6() {
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
}
output, err := cmd.CombinedOutput()
t.Logf("route output: %s", output)
require.NoError(t, err, "%s failed")
lines := strings.Split(string(output), "\n")
var intf string
var gateway string
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "interface:") {
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
} else if strings.HasPrefix(line, "gateway:") {
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
}
}
require.NotEmpty(t, intf, "interface should be found in route output")
iface, err := net.InterfaceByName(intf)
require.NoError(t, err, "interface %s should exist", intf)
nexthop := Nexthop{Intf: iface}
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
addr, err := netip.ParseAddr(gateway)
if err == nil {
nexthop.IP = addr
}
}
return nexthop
}
func TestAddRouteToNonVPNIntf(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
expectError bool
errorType error
}{
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("8.8.8.8/32"),
},
{
name: "IPv6 External network route",
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2001:db8::1/128"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
},
// Addresses that should be rejected
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Unspecified",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Unspecified",
prefix: netip.MustParsePrefix("::/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() &&
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
t.Skip("Skipping test as no ipv6 default route is available")
}
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Fatalf("Failed to get IPv6 default route: %v", err)
}
var initialNextHop Nexthop
if testCase.prefix.Addr().Is6() {
initialNextHop = initialNextHopV6
} else { } else {
assertWGOutInterface(t, testCase.prefix, wgInterface, true) initialNextHop = initialNextHopV4
} }
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixNexthop, err := GetNextHop(testCase.prefix.Addr()) nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
require.NoError(t, err, "GetNextHop should not return err")
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) if testCase.expectError {
require.NoError(t, err) require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
}
} }
require.NoError(t, err)
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
// Verify the route was added and points to non-VPN interface
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err)
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
err = r.removeFromRouteTable(testCase.prefix, nexthop)
assert.NoError(t, err)
}) })
} }
} }
func TestGetNextHop(t *testing.T) { func TestGetNextHop(t *testing.T) {
if runtime.GOOS == "freebsd" { defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil { if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err) t.Fatal("shouldn't return error when fetching the gateway: ", err)
} }
if !nexthop.IP.IsValid() { if !defaultNh.IP.IsValid() {
t.Fatal("should return a gateway") t.Fatal("should return a gateway")
} }
addresses, err := net.InterfaceAddrs() addresses, err := net.InterfaceAddrs()
@@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) {
t.Fatal("shouldn't return error when fetching interface addresses: ", err) t.Fatal("shouldn't return error when fetching interface addresses: ", err)
} }
var testingIP string
var testingPrefix netip.Prefix var testingPrefix netip.Prefix
for _, address := range addresses { for _, address := range addresses {
if address.Network() != "ip+net" { if address.Network() != "ip+net" {
@@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) {
} }
prefix := netip.MustParsePrefix(address.String()) prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked() testingPrefix = prefix.Masked()
break break
} }
} }
localIP, err := GetNextHop(testingPrefix.Addr()) nh, err := GetNextHop(testingPrefix.Addr())
if err != nil { if err != nil {
t.Fatal("shouldn't return error: ", err) t.Fatal("shouldn't return error: ", err)
} }
if !localIP.IP.IsValid() { if nh.Intf == nil {
t.Fatal("should return a gateway for local network") t.Fatal("should return a gateway for local network")
} }
if localIP.IP.String() == nexthop.IP.String() { if nh.IP.String() == defaultNh.IP.String() {
t.Fatal("local IP should not match with gateway IP") t.Fatal("next hop IP should not match with default gateway IP")
} }
if localIP.IP.String() != testingIP { if nh.Intf.Name != defaultNh.Intf.Name {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String()) t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPort: 33100,
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface, nil)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
switch {
case p.Addr().Is6():
continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
} }
} }
@@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) { func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper() t.Helper()
err := r.AddVPNRoute(prefix, intf) if err := r.AddVPNRoute(prefix, intf); err != nil {
require.NoError(t, err, "addVPNRoute should not return err") if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("addVPNRoute should not return err: %v", err)
}
t.Logf("addVPNRoute %v returned: %v", prefix, err)
}
t.Cleanup(func() { t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf) if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
assert.NoError(t, err, "removeVPNRoute should not return err") t.Fatalf("removeVPNRoute should not return err: %v", err)
}
}) })
} }
@@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) {
// 10.10.0.0/24 more specific route exists in vpn table // 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
// 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
// unique route in vpn table // unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
} }
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) { func TestIsVpnRoute(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
} }
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
return r.genericAddVPNRoute(prefix, intf) return r.genericAddVPNRoute(prefix, intf)
} }
@@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
} }
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }
@@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
ones, _ := route.Dst.Mask.Size() ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones) prefix := netip.PrefixFrom(addr.Unmap(), ones)
if prefix.IsValid() { if prefix.IsValid() {
prefixList = append(prefixList, prefix) prefixList = append(prefixList, prefix)
} }
@@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
return fmt.Errorf("add gateway and device: %w", err) return fmt.Errorf("add gateway and device: %w", err)
} }
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add route: %w", err) return fmt.Errorf("netlink add route: %w", err)
} }
@@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet, Dst: ipNet,
} }
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add unreachable route: %w", err) return fmt.Errorf("netlink add unreachable route: %w", err)
} }

View File

@@ -19,7 +19,6 @@ import (
) )
var expectedVPNint = "wgtest0" var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0" var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0" var expectedInternalInt = "dummyint0"
@@ -31,12 +30,6 @@ func init() {
dialer: &net.Dialer{}, dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
}, },
{
name: "To more specific route (local) without custom dialer via physical interface",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...) }...)
} }

View File

@@ -11,10 +11,16 @@ import (
) )
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericAddVPNRoute(prefix, intf) return r.genericAddVPNRoute(prefix, intf)
} }
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }

View File

@@ -0,0 +1,268 @@
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type mockWGIface struct {
address wgaddr.Address
name string
}
func (m *mockWGIface) Address() wgaddr.Address {
return m.address
}
func (m *mockWGIface) Name() string {
return m.name
}
func TestSysOps_validateRoute(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
}{
// Valid routes
{
name: "valid IPv4 route",
prefix: "192.168.1.0/24",
expectError: false,
},
{
name: "valid IPv6 route",
prefix: "2001:db8::/32",
expectError: false,
},
{
name: "valid single IPv4 host",
prefix: "8.8.8.8/32",
expectError: false,
},
{
name: "valid single IPv6 host",
prefix: "2001:4860:4860::8888/128",
expectError: false,
},
// Invalid routes - loopback
{
name: "IPv4 loopback",
prefix: "127.0.0.1/32",
expectError: true,
},
{
name: "IPv6 loopback",
prefix: "::1/128",
expectError: true,
},
// Invalid routes - link-local unicast
{
name: "IPv4 link-local unicast",
prefix: "169.254.1.1/32",
expectError: true,
},
{
name: "IPv6 link-local unicast",
prefix: "fe80::1/128",
expectError: true,
},
// Invalid routes - multicast
{
name: "IPv4 multicast",
prefix: "224.0.0.1/32",
expectError: true,
},
{
name: "IPv6 multicast",
prefix: "ff02::1/128",
expectError: true,
},
// Invalid routes - link-local multicast
{
name: "IPv4 link-local multicast",
prefix: "224.0.0.0/24",
expectError: true,
},
{
name: "IPv6 link-local multicast",
prefix: "ff02::/16",
expectError: true,
},
// Invalid routes - interface-local multicast (IPv6 only)
{
name: "IPv6 interface-local multicast",
prefix: "ff01::1/128",
expectError: true,
},
// Invalid routes - overlaps with WG interface network
{
name: "overlaps with WG network - exact match",
prefix: "10.0.0.0/24",
expectError: true,
},
{
name: "overlaps with WG network - subset",
prefix: "10.0.0.1/32",
expectError: true,
},
{
name: "overlaps with WG network - host in range",
prefix: "10.0.0.100/32",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
}
})
}
}
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
description string
}{
{
name: "identical subnet",
prefix: "192.168.100.0/24",
expectError: true,
description: "exact same network as WG interface",
},
{
name: "broader subnet containing WG network",
prefix: "192.168.0.0/16",
expectError: false,
description: "broader network that contains WG network should be allowed",
},
{
name: "host within WG network",
prefix: "192.168.100.50/32",
expectError: true,
description: "specific host within WG network",
},
{
name: "subnet within WG network",
prefix: "192.168.100.128/25",
expectError: true,
description: "smaller subnet within WG network",
},
{
name: "adjacent subnet - same /23",
prefix: "192.168.101.0/24",
expectError: false,
description: "adjacent subnet, no overlap",
},
{
name: "adjacent subnet - different /16",
prefix: "192.167.100.0/24",
expectError: false,
description: "different network, no overlap",
},
{
name: "WG network broadcast address",
prefix: "192.168.100.255/32",
expectError: true,
description: "broadcast address of WG network",
},
{
name: "WG network first usable",
prefix: "192.168.100.1/32",
expectError: true,
description: "first usable address in WG network",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
}
})
}
}
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wt0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
var invalidPrefix netip.Prefix
err := sysOps.validateRoute(invalidPrefix)
require.Error(t, err, "validateRoute() expected error for invalid prefix")
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
}

View File

@@ -3,15 +3,19 @@
package systemops package systemops
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os/exec" "strconv"
"strings" "syscall"
"time" "time"
"unsafe"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
} }
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("add", prefix, nexthop) return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
} }
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("delete", prefix, nexthop) return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
} }
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet" if !prefix.IsValid() {
if prefix.Addr().Is6() { return fmt.Errorf("invalid prefix: %s", prefix)
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
} }
expBackOff := backoff.NewExponentialBackOff() expBackOff := backoff.NewExponentialBackOff()
@@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error {
expBackOff.MaxInterval = 500 * time.Millisecond expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff) if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
if err != nil { a := "add"
return fmt.Errorf("route cmd retry failed: %w", err) if action == unix.RTM_DELETE {
a = "remove"
}
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
} }
return nil return nil
} }
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Version: unix.RTM_VERSION,
Seq: 1,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
if err != nil {
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
}
if prefix.IsSingleIP() {
msg.Flags |= unix.RTF_HOST
} else {
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
if err != nil {
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
}
}
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
}
} else if nexthop.Intf != nil {
msg.Index = nexthop.Intf.Index
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {
return &route.Inet4Addr{IP: addr.As4()}, nil
}
if addr.Zone() == "" {
return &route.Inet6Addr{IP: addr.As16()}, nil
}
var zone int
// zone can be either a numeric zone ID or an interface name.
if z, err := strconv.Atoi(addr.Zone()); err == nil {
zone = z
} else {
iface, err := net.InterfaceByName(addr.Zone())
if err != nil {
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
}
zone = iface.Index
}
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
}
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
m := net.CIDRMask(bits, 32)
var maskBytes [4]byte
copy(maskBytes[:], m)
return &route.Inet4Addr{IP: maskBytes}, nil
}
if prefix.Addr().Is6() {
m := net.CIDRMask(bits, 128)
var maskBytes [16]byte
copy(maskBytes[:], m)
return &route.Inet6Addr{IP: maskBytes}, nil
}
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
}

Some files were not shown because too many files have changed in this diff Show More