mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 19:26:39 +00:00
Compare commits
34 Commits
v0.45.2
...
add-ns-pun
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff5eddf70b | ||
|
|
0f050e5fe1 | ||
|
|
0f7c7f1da2 | ||
|
|
b56f61bf1b | ||
|
|
64f111923e | ||
|
|
122a89c02b | ||
|
|
c6cceba381 | ||
|
|
6c0cdb6ed1 | ||
|
|
84354951d3 | ||
|
|
55957a1960 | ||
|
|
df82a45d99 | ||
|
|
9424b88db2 | ||
|
|
609654eee7 | ||
|
|
b604c66140 | ||
|
|
ea4d13e96d | ||
|
|
87148c503f | ||
|
|
0cd36baf67 | ||
|
|
06980e7fa0 | ||
|
|
1ce4ee0cef | ||
|
|
f367925496 | ||
|
|
616b19c064 | ||
|
|
af27aaf9af | ||
|
|
35287f8241 | ||
|
|
07b220d91b | ||
|
|
41cd4952f1 | ||
|
|
f16f0c7831 | ||
|
|
273160c682 | ||
|
|
1d6c360aec | ||
|
|
f04e7c3f06 | ||
|
|
3d89cd43c2 | ||
|
|
0eeda712d0 | ||
|
|
3e3268db5f | ||
|
|
31f0879e71 | ||
|
|
f25b5bb987 |
1
.github/workflows/golangci-lint.yml
vendored
1
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,6 @@ jobs:
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -172,11 +172,11 @@ jobs:
|
||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
# check relay values
|
||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||
grep '33445:33445' docker-compose.yml
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
@@ -29,7 +29,7 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
|
||||
@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
||||
return a.ipAnonymizer[ip]
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
||||
// Convert IP to netip.Addr
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return addr
|
||||
}
|
||||
|
||||
anonIP := a.AnonymizeIP(ip)
|
||||
|
||||
return net.UDPAddr{
|
||||
IP: anonIP.AsSlice(),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
}
|
||||
|
||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||
|
||||
@@ -39,7 +39,6 @@ const (
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
@@ -78,7 +77,6 @@ var (
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
blockLANAccess bool
|
||||
debugUploadBundle bool
|
||||
debugUploadBundleURL string
|
||||
lazyConnEnabled bool
|
||||
|
||||
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
}
|
||||
|
||||
func newSVCConfig() *service.Config {
|
||||
return &service.Config{
|
||||
config := &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||
Description: "Netbird mesh network client",
|
||||
Option: make(service.KeyValue),
|
||||
EnvVars: make(map[string]string),
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||
|
||||
@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||
}
|
||||
|
||||
if logFile != "console" {
|
||||
if logFile != "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ const (
|
||||
disableServerRoutesFlag = "disable-server-routes"
|
||||
disableDNSFlag = "disable-dns"
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
blockInboundFlag = "block-inbound"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -13,6 +15,8 @@ var (
|
||||
disableServerRoutes bool
|
||||
disableDNS bool
|
||||
disableFirewall bool
|
||||
blockLANAccess bool
|
||||
blockInbound bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -28,4 +32,11 @@ func init() {
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
|
||||
266
client/cmd/up.go
266
client/cmd/up.go
@@ -55,12 +55,11 @@ func init() {
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
@@ -119,83 +118,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup config: %v", err)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
@@ -203,7 +128,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
config, err := internal.UpdateOrCreateConfig(*ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
@@ -262,9 +187,141 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get setup key: %v", err)
|
||||
}
|
||||
|
||||
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup login request: %v", err)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
@@ -301,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
loginRequest.InterfaceName = &interfaceName
|
||||
}
|
||||
@@ -336,49 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
loginRequest.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
return &loginRequest, nil
|
||||
}
|
||||
|
||||
func validateNATExternalIPs(list []string) error {
|
||||
|
||||
@@ -147,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -198,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
@@ -219,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -248,10 +248,6 @@ func (r *router) deleteIpSet(setName string) error {
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
@@ -278,10 +274,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
|
||||
@@ -116,6 +116,8 @@ type Manager interface {
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
IsStateful() bool
|
||||
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
|
||||
@@ -170,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -324,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package nftables
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs2 := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
Data: ip.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
|
||||
@@ -573,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -1006,10 +1002,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ type Forwarder struct {
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
}
|
||||
|
||||
@@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
@@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
|
||||
}
|
||||
|
||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
|
||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return value.([]byte), true
|
||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||
|
||||
@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
|
||||
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
|
||||
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
wg.Wait()
|
||||
|
||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
|
||||
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
|
||||
}
|
||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
|
||||
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
|
||||
@@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
}
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
|
||||
@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
expected: true,
|
||||
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
expected: true,
|
||||
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
expected: true,
|
||||
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
expected: true,
|
||||
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
|
||||
@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -39,8 +39,12 @@ const (
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
|
||||
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
|
||||
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
|
||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||
)
|
||||
|
||||
@@ -71,7 +75,6 @@ type Manager struct {
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@@ -148,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
|
||||
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
@@ -269,7 +277,7 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||
case !m.netstack && m.nativeFirewall != nil:
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
@@ -326,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return m.stateful
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
@@ -606,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
if m.stateful {
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
}
|
||||
// for netflow we keep track even if the firewall is stateless
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -777,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack && m.localForwarding {
|
||||
return m.handleNetstackLocalTraffic(packetData)
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
// track inbound packets to get the correct direction and session id for flows
|
||||
@@ -789,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||
fwd := m.forwarder.Load()
|
||||
if fwd == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
@@ -1088,11 +1099,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
|
||||
@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Apply scenario-specific setup
|
||||
sc.setupFunc(manager)
|
||||
|
||||
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Pre-populate connection table
|
||||
srcIPs := generateRandomIPs(count)
|
||||
dstIPs := generateRandomIPs(count)
|
||||
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
srcIP := generateRandomIPs(1)[0]
|
||||
dstIP := generateRandomIPs(1)[0]
|
||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "post_handshake",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -19,12 +19,8 @@ import (
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
localIP := net.ParseIP("100.10.0.100")
|
||||
wgNet := &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
@@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = wgNet
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
localIP, wgNet, err := net.ParseCIDR(network)
|
||||
require.NoError(tb, err)
|
||||
wgNet := netip.MustParsePrefix(network)
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
@@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
@@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer func() {
|
||||
@@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||
manager.decoders = sync.Pool{
|
||||
|
||||
@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a.AsSlice()) {
|
||||
if u.address.Network.Contains(a) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var zeroKey wgtypes.Key
|
||||
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
}
|
||||
@@ -201,6 +203,47 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
func (c *KernelConfigurer) Close() {
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wgctl: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
log.Errorf("Got error while closing wgctl: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||
}
|
||||
fullStats := &Stats{
|
||||
DeviceName: wgDevice.Name,
|
||||
PublicKey: wgDevice.PublicKey.String(),
|
||||
ListenPort: wgDevice.ListenPort,
|
||||
FWMark: wgDevice.FirewallMark,
|
||||
Peers: []Peer{},
|
||||
}
|
||||
|
||||
for _, p := range wgDevice.Peers {
|
||||
peer := Peer{
|
||||
PublicKey: p.PublicKey.String(),
|
||||
AllowedIPs: p.AllowedIPs,
|
||||
TxBytes: p.TransmitBytes,
|
||||
RxBytes: p.ReceiveBytes,
|
||||
LastHandshake: p.LastHandshakeTime,
|
||||
PresharedKey: p.PresharedKey != zeroKey,
|
||||
}
|
||||
if p.Endpoint != nil {
|
||||
peer.Endpoint = *p.Endpoint
|
||||
}
|
||||
fullStats.Peers = append(fullStats.Peers, peer)
|
||||
}
|
||||
return fullStats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
wg, err := wgctrl.New()
|
||||
|
||||
@@ -19,10 +19,17 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
privateKey = "private_key"
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
allowedIP = "allowed_ip"
|
||||
endpoint = "endpoint"
|
||||
fwmark = "fwmark"
|
||||
listenPort = "listen_port"
|
||||
publicKey = "public_key"
|
||||
presharedKey = "preshared_key"
|
||||
)
|
||||
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
@@ -186,6 +193,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||
ipcStr, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
||||
}
|
||||
|
||||
return parseStatus(c.deviceName, ipcStr)
|
||||
}
|
||||
|
||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||
func (t *WGUSPConfigurer) startUAPI() {
|
||||
var err error
|
||||
@@ -365,3 +381,136 @@ func getFwmark() int {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
||||
// Decode hex string to bytes
|
||||
keyBytes, err := hex.DecodeString(hexKey)
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
||||
}
|
||||
|
||||
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
||||
if len(keyBytes) != 32 {
|
||||
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
||||
}
|
||||
|
||||
// Convert to wgtypes.Key
|
||||
var key wgtypes.Key
|
||||
copy(key[:], keyBytes)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
stats := &Stats{DeviceName: deviceName}
|
||||
var currentPeer *Peer
|
||||
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
key := parts[0]
|
||||
val := parts[1]
|
||||
|
||||
switch key {
|
||||
case privateKey:
|
||||
key, err := hexToWireguardKey(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse private key: %v", err)
|
||||
continue
|
||||
}
|
||||
stats.PublicKey = key.PublicKey().String()
|
||||
case publicKey:
|
||||
// Save previous peer
|
||||
if currentPeer != nil {
|
||||
stats.Peers = append(stats.Peers, *currentPeer)
|
||||
}
|
||||
key, err := hexToWireguardKey(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse public key: %v", err)
|
||||
continue
|
||||
}
|
||||
currentPeer = &Peer{
|
||||
PublicKey: key.String(),
|
||||
}
|
||||
case listenPort:
|
||||
if port, err := strconv.Atoi(val); err == nil {
|
||||
stats.ListenPort = port
|
||||
}
|
||||
case fwmark:
|
||||
if fwmark, err := strconv.Atoi(val); err == nil {
|
||||
stats.FWMark = fwmark
|
||||
}
|
||||
case endpoint:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint: %v", err)
|
||||
continue
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint port: %v", err)
|
||||
continue
|
||||
}
|
||||
currentPeer.Endpoint = net.UDPAddr{
|
||||
IP: net.ParseIP(host),
|
||||
Port: port,
|
||||
}
|
||||
case allowedIP:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(val)
|
||||
if err == nil {
|
||||
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
||||
}
|
||||
case ipcKeyTxBytes:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
rxBytes, err := toBytes(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.TxBytes = rxBytes
|
||||
case ipcKeyRxBytes:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
rxBytes, err := toBytes(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.RxBytes = rxBytes
|
||||
|
||||
case ipcKeyLastHandshakeTimeSec:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ts, err := toLastHandshake(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.LastHandshake = ts
|
||||
case presharedKey:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
if val != "" {
|
||||
currentPeer.PresharedKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentPeer != nil {
|
||||
stats.Peers = append(stats.Peers, *currentPeer)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
24
client/iface/configurer/wgshow.go
Normal file
24
client/iface/configurer/wgshow.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
PublicKey string
|
||||
Endpoint net.UDPAddr
|
||||
AllowedIPs []net.IPNet
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
LastHandshake time.Time
|
||||
PresharedKey bool
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
DeviceName string
|
||||
PublicKey string
|
||||
ListenPort int
|
||||
FWMark int
|
||||
Peers []Peer
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||
@@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
||||
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) {
|
||||
log.Info("create tun interface")
|
||||
|
||||
routesString := routesToString(routes)
|
||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||
searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList())
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -24,9 +23,6 @@ type PacketFilter interface {
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
SetNetwork(*net.IPNet)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create nbnetstack tun interface")
|
||||
|
||||
// TODO: get from service listener runtime IP
|
||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("last ip: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
|
||||
@@ -17,4 +17,5 @@ type WGConfigurer interface {
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close()
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
}
|
||||
|
||||
@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
}
|
||||
|
||||
ip := address.IP.String()
|
||||
mask := "0x" + address.Network.Mask.String()
|
||||
|
||||
// Convert prefix length to hex netmask
|
||||
prefixLen := address.Network.Bits()
|
||||
if !address.IP.Is4() {
|
||||
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||
}
|
||||
|
||||
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||
|
||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
|
||||
@@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
||||
}
|
||||
|
||||
w.filter = filter
|
||||
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
||||
|
||||
w.tun.FilteredDevice().SetFilter(filter)
|
||||
return nil
|
||||
@@ -217,6 +216,10 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
return w.configurer.GetStats()
|
||||
}
|
||||
|
||||
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||
return w.configurer.FullStats()
|
||||
}
|
||||
|
||||
func (w *WGIface) waitUntilRemoved() error {
|
||||
maxWaitTime := 5 * time.Second
|
||||
timeout := time.NewTimer(maxWaitTime)
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
|
||||
package iface
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
@@ -21,6 +25,6 @@ func (w *WGIface) Create() error {
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
|
||||
return fmt.Errorf("this function has not implemented on non mobile")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,13 @@ package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
@@ -36,6 +38,6 @@ func (w *WGIface) Create() error {
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -15,8 +13,8 @@ import (
|
||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||
|
||||
type NetStackTun struct { //nolint:revive
|
||||
address net.IP
|
||||
dnsAddress net.IP
|
||||
address netip.Addr
|
||||
dnsAddress netip.Addr
|
||||
mtu int
|
||||
listenAddress string
|
||||
|
||||
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
|
||||
tundev tun.Device
|
||||
}
|
||||
|
||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
||||
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||
return &NetStackTun{
|
||||
address: address,
|
||||
dnsAddress: dnsAddress,
|
||||
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
addr, ok := netip.AddrFromSlice(t.address)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
||||
}
|
||||
|
||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
||||
}
|
||||
|
||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||
[]netip.Addr{addr.Unmap()},
|
||||
[]netip.Addr{dnsAddr.Unmap()},
|
||||
[]netip.Addr{t.address},
|
||||
[]netip.Addr{t.dnsAddress},
|
||||
t.mtu)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -2,28 +2,27 @@ package wgaddr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
IP net.IP
|
||||
Network *net.IPNet
|
||||
IP netip.Addr
|
||||
Network netip.Prefix
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func ParseWGAddress(address string) (Address, error) {
|
||||
ip, network, err := net.ParseCIDR(address)
|
||||
prefix, err := netip.ParsePrefix(address)
|
||||
if err != nil {
|
||||
return Address{}, err
|
||||
}
|
||||
return Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
IP: prefix.Addr().Unmap(),
|
||||
Network: prefix.Masked(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (addr Address) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||
}
|
||||
|
||||
@@ -58,6 +58,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
if d.firewall == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
total := 0
|
||||
@@ -69,14 +74,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
if d.firewall == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
}
|
||||
@@ -285,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
// return traffic for outbound connections if firewall is stateless
|
||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||
return
|
||||
if fw.IsStateful() {
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||
} else {
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -94,12 +91,13 @@ func TestDefaultManager(t *testing.T) {
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied")
|
||||
return
|
||||
expectedRules := 2
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 1 // only the inbound rule
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
if previousCount != 1 {
|
||||
t.Errorf("old rule was not removed")
|
||||
|
||||
expectedPreviousCount := 0
|
||||
if !fw.IsStateful() {
|
||||
expectedPreviousCount = 1
|
||||
}
|
||||
assert.Equal(t, expectedPreviousCount, previousCount)
|
||||
})
|
||||
|
||||
t.Run("handle default rules", func(t *testing.T) {
|
||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = true
|
||||
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
if len(acl.peerRulesPairs) != 1 {
|
||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
|
||||
expectedRules := 1
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 1 // only inbound allow-all rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultManagerStateless(t *testing.T) {
|
||||
// stateless currently only in userspace, so we have to disable kernel
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Port: "53",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// In stateless mode, we should have both inbound and outbound rules
|
||||
assert.False(t, fw.IsStateful())
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
if len(rules) != 2 {
|
||||
t.Errorf("rules should contain 2, got: %v", rules)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, 2, len(rules))
|
||||
|
||||
r := rules[0]
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_IN:
|
||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
|
||||
r = rules[1]
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_OUT:
|
||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
@@ -291,9 +326,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
||||
}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
@@ -336,33 +370,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
if len(acl.peerRulesPairs) != 3 {
|
||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
return
|
||||
expectedRules := 3
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
}
|
||||
|
||||
@@ -68,8 +68,8 @@ type ConfigInput struct {
|
||||
DisableServerRoutes *bool
|
||||
DisableDNS *bool
|
||||
DisableFirewall *bool
|
||||
|
||||
BlockLANAccess *bool
|
||||
BlockLANAccess *bool
|
||||
BlockInbound *bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
@@ -98,8 +98,8 @@ type Config struct {
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
|
||||
BlockLANAccess bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
@@ -483,6 +483,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
|
||||
if *input.BlockInbound {
|
||||
log.Infof("blocking inbound connections")
|
||||
} else {
|
||||
log.Infof("allowing inbound connections")
|
||||
}
|
||||
config.BlockInbound = *input.BlockInbound
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||
if *input.DisableNotifications {
|
||||
log.Infof("disabling notifications")
|
||||
|
||||
@@ -436,11 +436,12 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||
DisableDNS: config.DisableDNS,
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
BlockInbound: config.BlockInbound,
|
||||
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
}
|
||||
|
||||
|
||||
@@ -270,11 +270,21 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if g.logFile != "console" {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
}
|
||||
if err := g.addWgShow(); err != nil {
|
||||
log.Errorf("Failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if g.logFile != "console" && g.logFile != "" {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("Failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
log.Errorf("Failed to add systemd logs as fallback: %v", err)
|
||||
}
|
||||
}
|
||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||
log.Errorf("Failed to add systemd logs: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -366,17 +376,33 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
|
||||
if g.internalConfig.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
|
||||
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
|
||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
|
||||
|
||||
if g.internalConfig.DisableNotifications != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
|
||||
|
||||
if g.internalConfig.ClientCertPath != "" {
|
||||
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath))
|
||||
}
|
||||
if g.internalConfig.ClientCertKeyPath != "" {
|
||||
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||
}
|
||||
|
||||
|
||||
@@ -4,17 +4,104 @@ package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
maxLogEntries = 100000
|
||||
maxLogAge = 7 * 24 * time.Hour // Last 7 days
|
||||
)
|
||||
|
||||
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
|
||||
func (g *BundleGenerator) trySystemdLogFallback() error {
|
||||
log.Debug("Attempting to collect systemd journal logs")
|
||||
|
||||
serviceName := getServiceName()
|
||||
journalLogs, err := getSystemdLogs(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
|
||||
}
|
||||
|
||||
if strings.Contains(journalLogs, "No recent log entries found") {
|
||||
log.Debug("No recent log entries found in systemd journal")
|
||||
return nil
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
|
||||
}
|
||||
|
||||
logReader := strings.NewReader(journalLogs)
|
||||
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
|
||||
if err := g.addFileToZip(logReader, fileName); err != nil {
|
||||
return fmt.Errorf("add systemd logs to bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getServiceName gets the service name from environment or defaults to netbird
|
||||
func getServiceName() string {
|
||||
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
|
||||
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
|
||||
return unitName
|
||||
}
|
||||
|
||||
return "netbird"
|
||||
}
|
||||
|
||||
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
|
||||
func getSystemdLogs(serviceName string) (string, error) {
|
||||
args := []string{
|
||||
"-u", fmt.Sprintf("%s.service", serviceName),
|
||||
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
|
||||
"--lines", fmt.Sprintf("%d", maxLogEntries),
|
||||
"--no-pager",
|
||||
"--output", "short-iso",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "journalctl", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
|
||||
}
|
||||
if strings.Contains(err.Error(), "executable file not found") {
|
||||
return "", fmt.Errorf("journalctl command not found: %w", err)
|
||||
}
|
||||
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
logs := stdout.String()
|
||||
if strings.TrimSpace(logs) == "" {
|
||||
return "No recent log entries found in systemd journal", nil
|
||||
}
|
||||
|
||||
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
|
||||
serviceName, maxLogEntries, maxLogAge.String())
|
||||
|
||||
return header + logs, nil
|
||||
}
|
||||
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
@@ -481,7 +568,7 @@ func formatExpr(exp expr.Any) string {
|
||||
case *expr.Fib:
|
||||
return formatFib(e)
|
||||
case *expr.Target:
|
||||
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets
|
||||
return fmt.Sprintf("jump %s", e.Name)
|
||||
case *expr.Immediate:
|
||||
if e.Register == 1 {
|
||||
return formatImmediateData(e.Data)
|
||||
|
||||
@@ -6,3 +6,9 @@ package debug
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) trySystemdLogFallback() error {
|
||||
// Systemd is only available on Linux
|
||||
// TODO: Add BSD support
|
||||
return nil
|
||||
}
|
||||
|
||||
66
client/internal/debug/wgshow.go
Normal file
66
client/internal/debug/wgshow.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
FullStats() (*configurer.Stats, error)
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addWgShow() error {
|
||||
result, err := g.statusRecorder.PeersStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output := g.toWGShowFormat(result)
|
||||
reader := bytes.NewReader([]byte(output))
|
||||
|
||||
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
|
||||
return fmt.Errorf("add wg show to zip: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
|
||||
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
|
||||
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
|
||||
if s.FWMark != 0 {
|
||||
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
|
||||
}
|
||||
|
||||
for _, peer := range s.Peers {
|
||||
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
|
||||
if peer.Endpoint.IP != nil {
|
||||
if g.anonymize {
|
||||
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
|
||||
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
|
||||
}
|
||||
}
|
||||
if len(peer.AllowedIPs) > 0 {
|
||||
var ipStrings []string
|
||||
for _, ipnet := range peer.AllowedIPs {
|
||||
ipStrings = append(ipStrings, ipnet.String())
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
||||
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
||||
if peer.PresharedKey {
|
||||
sb.WriteString(" preshared key: (hidden)\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -12,13 +12,14 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||
ip, err := netip.ParseAddr(aRecord.RData)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
if !prefix.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||
networkIP := network.Masked().Addr()
|
||||
|
||||
if !networkIP.Is4() {
|
||||
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||
}
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
octetsToUse := (network.Bits() + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||
zoneName, err := generateReverseZoneName(network)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
records := collectPTRRecords(config, network)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
|
||||
@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||
}
|
||||
return listOfDomains
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,8 +26,8 @@ type SubdomainMatcher interface {
|
||||
type HandlerEntry struct {
|
||||
Handler dns.Handler
|
||||
Priority int
|
||||
Pattern string
|
||||
OrigPattern string
|
||||
Pattern domain.Domain
|
||||
OrigPattern domain.Domain
|
||||
IsWildcard bool
|
||||
MatchSubdomains bool
|
||||
}
|
||||
@@ -38,7 +41,7 @@ type HandlerChain struct {
|
||||
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
origPattern string
|
||||
origPattern domain.Domain
|
||||
shouldContinue bool
|
||||
}
|
||||
|
||||
@@ -58,18 +61,18 @@ func NewHandlerChain() *HandlerChain {
|
||||
}
|
||||
|
||||
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
||||
func (w *ResponseWriterChain) GetOrigPattern() string {
|
||||
func (w *ResponseWriterChain) GetOrigPattern() domain.Domain {
|
||||
return w.origPattern
|
||||
}
|
||||
|
||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||
func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString())))
|
||||
origPattern := pattern
|
||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||
isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.")
|
||||
if isWildcard {
|
||||
pattern = pattern[2:]
|
||||
}
|
||||
@@ -109,8 +112,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
||||
|
||||
// domain specificity next
|
||||
if h.Priority == newEntry.Priority {
|
||||
newDots := strings.Count(newEntry.Pattern, ".")
|
||||
existingDots := strings.Count(h.Pattern, ".")
|
||||
newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".")
|
||||
existingDots := strings.Count(h.Pattern.PunycodeString(), ".")
|
||||
if newDots > existingDots {
|
||||
return i
|
||||
}
|
||||
@@ -122,20 +125,20 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
||||
}
|
||||
|
||||
// RemoveHandler removes a handler for the given pattern and priority
|
||||
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString()))
|
||||
|
||||
c.removeEntry(pattern, priority)
|
||||
}
|
||||
|
||||
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) {
|
||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority {
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
@@ -148,61 +151,42 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
log.Tracef("handling DNS request for domain=%s", qname)
|
||||
|
||||
c.mu.RLock()
|
||||
handlers := slices.Clone(c.handlers)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("current handlers (%d):", len(handlers))
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
||||
for _, h := range handlers {
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
var matched bool
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
matched = true
|
||||
case entry.IsWildcard:
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
default:
|
||||
// For non-wildcard patterns:
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
// Otherwise require exact match
|
||||
if entry.MatchSubdomains {
|
||||
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
} else {
|
||||
matched = strings.EqualFold(qname, entry.Pattern)
|
||||
matched := c.isHandlerMatch(qname, entry)
|
||||
|
||||
if matched {
|
||||
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
}
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
if !matched {
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||
continue
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
log.Tracef("handler requested continue to next handler")
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
@@ -213,3 +197,22 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
return true
|
||||
case entry.IsWildcard:
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".")
|
||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString())
|
||||
default:
|
||||
// For non-wildcard patterns:
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
// Otherwise require exact match
|
||||
if entry.MatchSubdomains {
|
||||
return strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString())
|
||||
} else {
|
||||
return strings.EqualFold(qname, entry.Pattern.PunycodeString())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||
@@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handlerDomain string
|
||||
queryDomain string
|
||||
handlerDomain domain.Domain
|
||||
queryDomain domain.Domain
|
||||
isWildcard bool
|
||||
matchSubdomains bool
|
||||
shouldMatch bool
|
||||
@@ -141,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
@@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handlers []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}
|
||||
queryDomain string
|
||||
queryDomain domain.Domain
|
||||
expectedCalls int
|
||||
expectedHandler int // index of the handler that should be called
|
||||
}{
|
||||
{
|
||||
name: "wildcard and exact same priority - exact should win",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
@@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
{
|
||||
name: "higher priority wildcard over lower priority exact",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||
@@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
{
|
||||
name: "multiple wildcards different priorities",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
@@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
{
|
||||
name: "subdomain with mix of patterns",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
@@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
{
|
||||
name: "root zone with specific domain",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{pattern: ".", priority: nbdns.PriorityDefault},
|
||||
@@ -258,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
|
||||
// Create and execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
@@ -330,7 +331,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
name string
|
||||
ops []struct {
|
||||
action string // "add" or "remove"
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}
|
||||
query string
|
||||
@@ -340,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
name: "remove high priority keeps lower priority handler",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
@@ -357,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
name: "remove lower priority keeps high priority handler",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
@@ -374,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
name: "remove all handlers in order",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
@@ -436,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
testDomain := "example.com."
|
||||
testDomain := domain.Domain("example.com.")
|
||||
testQuery := "test.example.com."
|
||||
|
||||
// Create handlers with MatchSubdomains enabled
|
||||
@@ -518,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
name string
|
||||
scenario string
|
||||
addHandlers []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
@@ -530,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
name: "case insensitive exact match",
|
||||
scenario: "handler registered lowercase, query uppercase",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
@@ -544,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
name: "case insensitive wildcard match",
|
||||
scenario: "handler registered mixed case wildcard, query different case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
@@ -558,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
name: "multiple handlers different case same domain",
|
||||
scenario: "second handler should replace first despite case difference",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
@@ -573,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
name: "subdomain matching case insensitive",
|
||||
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
@@ -587,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
name: "root zone case insensitive",
|
||||
scenario: "root zone handler should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
@@ -601,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
name: "multiple handlers different priority",
|
||||
scenario: "should call higher priority handler despite case differences",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
@@ -618,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlerCalls := make(map[string]bool) // track which patterns were called
|
||||
handlerCalls := make(map[domain.Domain]bool) // track which patterns were called
|
||||
|
||||
// Add handlers according to test case
|
||||
for _, h := range tt.addHandlers {
|
||||
@@ -686,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
scenario string
|
||||
ops []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomain bool
|
||||
}
|
||||
query string
|
||||
expectedMatch string
|
||||
query domain.Domain
|
||||
expectedMatch domain.Domain
|
||||
}{
|
||||
{
|
||||
name: "more specific domain matches first",
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
@@ -713,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
@@ -728,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
scenario: "after removing most specific, should fall back to less specific",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
@@ -745,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
scenario: "less specific domain with higher priority should match first",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
@@ -760,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
scenario: "with equal priority, more specific domain should match",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
@@ -776,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
scenario: "specific domain should match before wildcard at same priority",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
pattern domain.Domain
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
@@ -791,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
||||
handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler)
|
||||
|
||||
for _, op := range tt.ops {
|
||||
if op.action == "add" {
|
||||
@@ -804,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
}
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
r.SetQuestion(tt.query.PunycodeString(), dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// Setup handler expectations
|
||||
@@ -836,9 +837,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addPattern string
|
||||
removePattern string
|
||||
queryPattern string
|
||||
addPattern domain.Domain
|
||||
removePattern domain.Domain
|
||||
queryPattern domain.Domain
|
||||
shouldBeRemoved bool
|
||||
description string
|
||||
}{
|
||||
@@ -954,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
|
||||
handler := &nbdns.MockHandler{}
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||
r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// First verify no handler is called before adding any
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
@@ -39,9 +40,9 @@ type HostDNSConfig struct {
|
||||
}
|
||||
|
||||
type DomainConfig struct {
|
||||
Disabled bool `json:"disabled"`
|
||||
Domain string `json:"domain"`
|
||||
MatchOnly bool `json:"matchOnly"`
|
||||
Disabled bool `json:"disabled"`
|
||||
Domain domain.Domain `json:"domain"`
|
||||
MatchOnly bool `json:"matchOnly"`
|
||||
}
|
||||
|
||||
type mockHostConfigurator struct {
|
||||
@@ -103,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
config.RouteAll = true
|
||||
}
|
||||
|
||||
for _, domain := range nsConfig.Domains {
|
||||
for _, d := range nsConfig.Domains {
|
||||
d := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.ToLower(dns.Fqdn(domain)),
|
||||
Domain: domain.Domain(d),
|
||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
d := strings.ToLower(dns.Fqdn(customZone.Domain))
|
||||
matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||
Domain: domain.Domain(d),
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), "."))
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -41,6 +44,20 @@ const (
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
|
||||
// Network interface DNS registration settings
|
||||
disableDynamicUpdateKey = "DisableDynamicUpdate"
|
||||
registrationEnabledKey = "RegistrationEnabled"
|
||||
maxNumberOfAddressesToRegisterKey = "MaxNumberOfAddressesToRegister"
|
||||
|
||||
// NetBIOS/WINS settings
|
||||
netbtInterfacePath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces`
|
||||
netbiosOptionsKey = "NetbiosOptions"
|
||||
|
||||
// NetBIOS option values: 0 = from DHCP, 1 = enabled, 2 = disabled
|
||||
netbiosFromDHCP = 0
|
||||
netbiosEnabled = 1
|
||||
netbiosDisabled = 2
|
||||
|
||||
// RP_FORCE: Reapply all policies even if no policy change was detected
|
||||
rpForce = 0x1
|
||||
)
|
||||
@@ -67,16 +84,85 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
log.Infof("detected GPO DNS policy configuration, using policy store")
|
||||
}
|
||||
|
||||
return ®istryConfigurator{
|
||||
configurator := ®istryConfigurator{
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := configurator.configureInterface(); err != nil {
|
||||
log.Errorf("failed to configure interface settings: %v", err)
|
||||
}
|
||||
|
||||
return configurator, nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) configureInterface() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.disableDNSRegistrationForInterface(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("disable DNS registration: %w", err))
|
||||
}
|
||||
|
||||
if err := r.disableWINSForInterface(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("disable WINS: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) disableDNSRegistrationForInterface() error {
|
||||
regKey, err := r.getInterfaceRegistryKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := regKey.SetDWordValue(disableDynamicUpdateKey, 1); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", disableDynamicUpdateKey, err))
|
||||
}
|
||||
|
||||
if err := regKey.SetDWordValue(registrationEnabledKey, 0); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", registrationEnabledKey, err))
|
||||
}
|
||||
|
||||
if err := regKey.SetDWordValue(maxNumberOfAddressesToRegisterKey, 0); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", maxNumberOfAddressesToRegisterKey, err))
|
||||
}
|
||||
|
||||
if merr == nil || len(merr.Errors) == 0 {
|
||||
log.Infof("disabled DNS registration for interface %s", r.guid)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) disableWINSForInterface() error {
|
||||
netbtKeyPath := fmt.Sprintf(`%s\Tcpip_%s`, netbtInterfacePath, r.guid)
|
||||
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
regKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create NetBT interface key %s: %w", netbtKeyPath, err)
|
||||
}
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
// NetbiosOptions: 2 = disabled
|
||||
if err := regKey.SetDWordValue(netbiosOptionsKey, netbiosDisabled); err != nil {
|
||||
return fmt.Errorf("set %s: %w", netbiosOptionsKey, err)
|
||||
}
|
||||
|
||||
log.Infof("disabled WINS/NetBIOS for interface %s", r.guid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
if config.RouteAll {
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
@@ -100,9 +186,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
continue
|
||||
}
|
||||
if !dConf.MatchOnly {
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||
}
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
@@ -119,9 +205,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
go r.flushDNSCache()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -191,7 +275,25 @@ func (r *registryConfigurator) string() string {
|
||||
return "registry"
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) flushDNSCache() error {
|
||||
func (r *registryConfigurator) registerDNS() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// nolint:misspell
|
||||
cmd := exec.CommandContext(ctx, "ipconfig", "/registerdns")
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to register DNS: %v, output: %s", err, out)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("registered DNS names")
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) flushDNSCache() {
|
||||
r.registerDNS()
|
||||
|
||||
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
@@ -202,13 +304,14 @@ func (r *registryConfigurator) flushDNSCache() error {
|
||||
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||
if ret == 0 {
|
||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||
log.Errorf("DnsFlushResolverCache failed: %v", err)
|
||||
return
|
||||
}
|
||||
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||
log.Errorf("DnsFlushResolverCache failed")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
@@ -263,9 +366,7 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
go r.flushDNSCache()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -62,8 +62,8 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
return fmt.Errorf("method UpdateDNSServer is not implemented")
|
||||
}
|
||||
|
||||
func (m *MockServer) SearchDomains() []string {
|
||||
return make([]string, 0)
|
||||
func (m *MockServer) SearchDomains() domain.List {
|
||||
return make(domain.List, 0)
|
||||
}
|
||||
|
||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||
|
||||
@@ -125,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
||||
matchDomains = append(matchDomains, "~."+dConf.Domain.PunycodeString())
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
|
||||
}
|
||||
|
||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
searchDomains []string
|
||||
searchDomains domain.List
|
||||
}
|
||||
|
||||
func newNotifier(initialSearchDomains []string) *notifier {
|
||||
sort.Strings(initialSearchDomains)
|
||||
func newNotifier(initialSearchDomains domain.List) *notifier {
|
||||
return ¬ifier{
|
||||
searchDomains: initialSearchDomains,
|
||||
}
|
||||
@@ -27,16 +25,8 @@ func (n *notifier) setListener(listener listener.NetworkChangeListener) {
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *notifier) onNewSearchDomains(searchDomains []string) {
|
||||
sort.Strings(searchDomains)
|
||||
|
||||
if len(n.searchDomains) != len(searchDomains) {
|
||||
n.searchDomains = searchDomains
|
||||
n.notify()
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(n.searchDomains, searchDomains) {
|
||||
func (n *notifier) onNewSearchDomains(searchDomains domain.List) {
|
||||
if searchDomains.Equal(n.searchDomains) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -44,12 +44,12 @@ type Server interface {
|
||||
DnsIP() string
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
SearchDomains() []string
|
||||
SearchDomains() domain.List
|
||||
ProbeAvailability()
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
domain string
|
||||
domain domain.Domain
|
||||
groups []*nbdns.NameServerGroup
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ type handlerWithStop interface {
|
||||
}
|
||||
|
||||
type handlerWrapper struct {
|
||||
domain string
|
||||
domain domain.Domain
|
||||
handler handlerWithStop
|
||||
priority int
|
||||
}
|
||||
@@ -197,7 +197,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
||||
s.registerHandler(domains, handler, priority)
|
||||
|
||||
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||
for _, domain := range domains {
|
||||
@@ -207,7 +207,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
func (s *DefaultServer) registerHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||
|
||||
for _, domain := range domains {
|
||||
@@ -224,7 +224,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
||||
s.deregisterHandler(domains, priority)
|
||||
for _, domain := range domains {
|
||||
zone := toZone(domain)
|
||||
s.extraDomains[zone]--
|
||||
@@ -235,7 +235,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
func (s *DefaultServer) deregisterHandler(domains domain.List, priority int) {
|
||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||
|
||||
for _, domain := range domains {
|
||||
@@ -378,8 +378,8 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) SearchDomains() []string {
|
||||
var searchDomains []string
|
||||
func (s *DefaultServer) SearchDomains() domain.List {
|
||||
var searchDomains domain.List
|
||||
|
||||
for _, dConf := range s.currentConfig.Domains {
|
||||
if dConf.Disabled {
|
||||
@@ -472,24 +472,22 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
|
||||
config := s.currentConfig
|
||||
|
||||
existingDomains := make(map[string]struct{})
|
||||
existingDomains := make(map[domain.Domain]struct{})
|
||||
for _, d := range config.Domains {
|
||||
existingDomains[d.Domain] = struct{}{}
|
||||
}
|
||||
|
||||
// add extra domains only if they're not already in the config
|
||||
for domain := range s.extraDomains {
|
||||
domainStr := domain.PunycodeString()
|
||||
|
||||
if _, exists := existingDomains[domainStr]; !exists {
|
||||
for d := range s.extraDomains {
|
||||
if _, exists := existingDomains[d]; !exists {
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: domainStr,
|
||||
Domain: d,
|
||||
MatchOnly: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||
@@ -525,7 +523,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
domain: customZone.Domain,
|
||||
domain: domain.Domain(customZone.Domain),
|
||||
handler: s.localResolver,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
@@ -647,7 +645,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
s.deregisterHandler(domain.List{existing.domain}, existing.priority)
|
||||
existing.handler.Stop()
|
||||
}
|
||||
|
||||
@@ -658,7 +656,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
if update.domain == nbdns.RootZone {
|
||||
containsRootUpdate = true
|
||||
}
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
s.registerHandler(domain.List{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.ID()] = update
|
||||
}
|
||||
|
||||
@@ -687,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
handler dns.Handler,
|
||||
priority int,
|
||||
) (deactivate func(error), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
var removeIndex map[domain.Domain]int
|
||||
deactivate = func(err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -695,20 +693,20 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("Temporarily deactivating nameservers group due to timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
removeIndex = make(map[domain.Domain]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
removeIndex[domain] = -1
|
||||
}
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.RouteAll = false
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, priority)
|
||||
s.deregisterHandler(domain.List{nbdns.RootZone}, priority)
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
if _, found := removeIndex[item.Domain]; found {
|
||||
s.currentConfig.Domains[i].Disabled = true
|
||||
s.deregisterHandler([]string{item.Domain}, priority)
|
||||
s.deregisterHandler(domain.List{item.Domain}, priority)
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
@@ -732,12 +730,12 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for domain, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
|
||||
for d, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != d{
|
||||
continue
|
||||
}
|
||||
s.currentConfig.Domains[i].Disabled = false
|
||||
s.registerHandler([]string{domain}, handler, priority)
|
||||
s.registerHandler(domain.List{d}, handler, priority)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@@ -745,7 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.RouteAll = true
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
s.registerHandler(domain.List{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
s.applyHostConfig()
|
||||
@@ -777,7 +775,7 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
handler.deactivate = func(error) {}
|
||||
handler.reactivate = func() {}
|
||||
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
s.registerHandler(domain.List{nbdns.RootZone}, handler, PriorityDefault)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
@@ -792,7 +790,7 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
state := peer.NSGroupState{
|
||||
ID: generateGroupKey(group),
|
||||
Servers: servers,
|
||||
Domains: group.Domains,
|
||||
Domains: group.Domains.ToPunycodeList(),
|
||||
// The probe will determine the state, default enabled
|
||||
Enabled: true,
|
||||
Error: nil,
|
||||
@@ -825,7 +823,7 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||
|
||||
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
||||
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
||||
domainMap := make(map[string][]*nbdns.NameServerGroup)
|
||||
domainMap := make(map[domain.Domain][]*nbdns.NameServerGroup)
|
||||
|
||||
for _, group := range nsGroups {
|
||||
if group.Primary {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -46,10 +45,9 @@ func (w *mocWGIface) Name() string {
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
IP: netip.MustParseAddr("100.66.100.1"),
|
||||
Network: netip.MustParsePrefix("100.66.100.0/24"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +95,7 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
func generateDummyHandler(domain domain.Domain, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []string
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, getNSHostPort(srv))
|
||||
@@ -152,7 +150,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
Domains: domain.List{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
@@ -184,7 +182,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
name: "New Config Should Succeed",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
@@ -202,7 +200,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
Domains: domain.List{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
@@ -303,8 +301,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
|
||||
domain: domain.Domain(zoneRecords[0].Name),
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
@@ -319,8 +317,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
|
||||
domain: domain.Domain(zoneRecords[0].Name),
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
@@ -464,17 +462,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
||||
if err != nil {
|
||||
t.Errorf("parse CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
@@ -501,7 +492,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
domain: domain.Domain(zoneRecords[0].Name),
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
@@ -533,7 +524,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
Domains: domain.List{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
@@ -599,7 +590,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
||||
dnsServer.registerHandler(domain.List{"netbird.cloud"}, dnsServer.localResolver, 1)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
@@ -659,48 +650,48 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||
domains := []string{}
|
||||
domains := domain.List{}
|
||||
for _, item := range config.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
domainsUpdate = strings.Join(domains, ",")
|
||||
domainsUpdate = domains.PunycodeString()
|
||||
return nil
|
||||
}
|
||||
|
||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||
Domains: []string{"domain1"},
|
||||
Domains: domain.List{"domain1"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil, 0)
|
||||
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
expected := "domain0, domain2"
|
||||
domains := domain.List{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got := strings.Join(domains, ",")
|
||||
got := domains.PunycodeString()
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||
}
|
||||
|
||||
reactivate()
|
||||
expected = "domain0,domain1,domain2"
|
||||
domains = []string{}
|
||||
expected = "domain0, domain1, domain2"
|
||||
domains = domain.List{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got = strings.Join(domains, ",")
|
||||
got = domains.PunycodeString()
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||
}
|
||||
@@ -868,7 +859,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"google.com"},
|
||||
Domains: domain.List{"google.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
@@ -1123,7 +1114,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
name string
|
||||
initialHandlers registeredHandlerMap
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[HandlerID]domain
|
||||
expectedHandlers map[string]domain.Domain // map[HandlerID]domain
|
||||
description string
|
||||
}{
|
||||
{
|
||||
@@ -1139,7 +1130,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-group2": "example.com",
|
||||
},
|
||||
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
||||
@@ -1157,7 +1148,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-group1": "example.com",
|
||||
},
|
||||
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
||||
@@ -1190,7 +1181,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-group3": "example.com",
|
||||
@@ -1225,7 +1216,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityMatchDomain - 2,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-group3": "example.com",
|
||||
@@ -1245,7 +1236,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-root2": ".",
|
||||
},
|
||||
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
||||
@@ -1262,7 +1253,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-root1": ".",
|
||||
},
|
||||
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
||||
@@ -1293,7 +1284,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-root1": ".",
|
||||
"upstream-root2": ".",
|
||||
"upstream-root3": ".",
|
||||
@@ -1326,7 +1317,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityDefault - 2,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-root1": ".",
|
||||
"upstream-root2": ".",
|
||||
"upstream-root3": ".",
|
||||
@@ -1353,7 +1344,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-other": "other.com",
|
||||
},
|
||||
@@ -1392,7 +1383,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
expectedHandlers: map[string]domain.Domain{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-other": "other.com",
|
||||
@@ -1448,7 +1439,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
for _, muxEntry := range server.dnsMuxMap {
|
||||
if chainEntry.Handler == muxEntry.handler &&
|
||||
chainEntry.Priority == muxEntry.priority &&
|
||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||
chainEntry.Pattern.PunycodeString() == dns.Fqdn(muxEntry.domain.PunycodeString()) {
|
||||
foundInMux = true
|
||||
break
|
||||
}
|
||||
@@ -1467,8 +1458,8 @@ func TestExtraDomains(t *testing.T) {
|
||||
registerDomains []domain.List
|
||||
deregisterDomains []domain.List
|
||||
finalConfig nbdns.Config
|
||||
expectedDomains []string
|
||||
expectedMatchOnly []string
|
||||
expectedDomains domain.List
|
||||
expectedMatchOnly domain.List
|
||||
applyHostConfigCall int
|
||||
}{
|
||||
{
|
||||
@@ -1482,12 +1473,12 @@ func TestExtraDomains(t *testing.T) {
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
@@ -1504,12 +1495,12 @@ func TestExtraDomains(t *testing.T) {
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
@@ -1527,12 +1518,12 @@ func TestExtraDomains(t *testing.T) {
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"config.example.com.",
|
||||
"overlap.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
@@ -1552,12 +1543,12 @@ func TestExtraDomains(t *testing.T) {
|
||||
deregisterDomains: []domain.List{
|
||||
{"extra1.example.com", "extra3.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"config.example.com.",
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
@@ -1578,13 +1569,13 @@ func TestExtraDomains(t *testing.T) {
|
||||
deregisterDomains: []domain.List{
|
||||
{"duplicate.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"config.example.com.",
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
@@ -1609,13 +1600,13 @@ func TestExtraDomains(t *testing.T) {
|
||||
{Domain: "newconfig.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"config.example.com.",
|
||||
"newconfig.example.com.",
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
@@ -1636,12 +1627,12 @@ func TestExtraDomains(t *testing.T) {
|
||||
deregisterDomains: []domain.List{
|
||||
{"protected.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"extra.example.com.",
|
||||
"config.example.com.",
|
||||
"protected.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
@@ -1652,7 +1643,7 @@ func TestExtraDomains(t *testing.T) {
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
||||
Domains: domain.List{"ns.example.com", "overlap.ns.example.com"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
@@ -1666,12 +1657,12 @@ func TestExtraDomains(t *testing.T) {
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.ns.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
expectedDomains: domain.List{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
expectedMatchOnly: domain.List{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
@@ -1742,8 +1733,8 @@ func TestExtraDomains(t *testing.T) {
|
||||
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||
|
||||
// Check all expected domains are present
|
||||
domainMap := make(map[string]bool)
|
||||
matchOnlyMap := make(map[string]bool)
|
||||
domainMap := make(map[domain.Domain]bool)
|
||||
matchOnlyMap := make(map[domain.Domain]bool)
|
||||
|
||||
for _, d := range lastConfig.Domains {
|
||||
domainMap[d.Domain] = true
|
||||
@@ -1860,12 +1851,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
err := server.applyConfiguration(initialConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
var domains domain.List
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
assert.Contains(t, domains, domain.Domain("config.example.com."))
|
||||
assert.Contains(t, domains, domain.Domain("extra.example.com."))
|
||||
|
||||
// Now apply a new configuration with overlapping domain
|
||||
updatedConfig := nbdns.Config{
|
||||
@@ -1879,7 +1870,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify both domains are in config, but no duplicates
|
||||
domains = []string{}
|
||||
domains = domain.List{}
|
||||
matchOnlyCount := 0
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
@@ -1888,12 +1879,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
assert.Contains(t, domains, domain.Domain("config.example.com."))
|
||||
assert.Contains(t, domains, domain.Domain("extra.example.com."))
|
||||
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||
|
||||
// Extra domain should no longer be marked as match-only when in config
|
||||
matchOnlyDomain := ""
|
||||
var matchOnlyDomain domain.Domain
|
||||
for _, d := range capturedConfig.Domains {
|
||||
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||
matchOnlyDomain = d.Domain
|
||||
@@ -1946,10 +1937,10 @@ func TestDomainCaseHandling(t *testing.T) {
|
||||
err := server.applyConfiguration(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
var domains domain.List
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||
assert.Contains(t, domains, domain.Domain("config.example.com."), "Mixed case domain should be normalized and pre.sent")
|
||||
assert.Contains(t, domains, domain.Domain("mixed.example.com."), "Mixed case domain should be normalized and present")
|
||||
}
|
||||
|
||||
@@ -24,11 +24,15 @@ type ServiceViaMemory struct {
|
||||
}
|
||||
|
||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
|
||||
runtimeIP: lastIP.String(),
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
return s
|
||||
@@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
if s.wgInterface.Address().IP.Is6() {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -117,15 +117,15 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
continue
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: dConf.Domain,
|
||||
Domain: dConf.Domain.PunycodeString(),
|
||||
MatchOnly: dConf.MatchOnly,
|
||||
})
|
||||
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, dConf.Domain)
|
||||
matchDomains = append(matchDomains, dConf.Domain.PunycodeString())
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []string
|
||||
domain string
|
||||
domain domain.Domain
|
||||
disabled bool
|
||||
failsCount atomic.Int32
|
||||
successCount atomic.Int32
|
||||
@@ -62,7 +63,7 @@ type upstreamResolverBase struct {
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain domain.Domain) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &upstreamResolverBase{
|
||||
|
||||
@@ -3,12 +3,14 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -23,11 +25,11 @@ type upstreamResolver struct {
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
domain domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
c := &upstreamResolver{
|
||||
|
||||
@@ -4,12 +4,13 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
@@ -19,11 +20,11 @@ type upstreamResolver struct {
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
domain domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
nonIOS := &upstreamResolver{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -14,23 +15,24 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
lIP netip.Addr
|
||||
lNet netip.Prefix
|
||||
interfaceName string
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
interfaceName string,
|
||||
ip net.IP,
|
||||
net *net.IPNet,
|
||||
ip netip.Addr,
|
||||
net netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
domain domain.Domain,
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
|
||||
@@ -58,8 +60,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
client.DialTimeout = timeout
|
||||
|
||||
upstreamIP := net.ParseIP(upstreamHost)
|
||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||
}
|
||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
@@ -73,7 +78,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
@@ -82,7 +87,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
|
||||
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: ip,
|
||||
IP: ip.AsSlice(),
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
|
||||
@@ -2,7 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
|
||||
@@ -121,8 +121,8 @@ type EngineConfig struct {
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
|
||||
BlockLANAccess bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
@@ -359,6 +359,7 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("new wg interface: %w", err)
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
e.statusRecorder.SetWgIface(wgIface)
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
@@ -380,7 +381,6 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
e.stateManager.Start()
|
||||
|
||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||
@@ -431,7 +431,8 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
// if inbound conns are blocked there is no need to create the ACL manager
|
||||
if e.firewall != nil && !e.config.BlockInbound {
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
}
|
||||
|
||||
@@ -487,11 +488,9 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if e.firewall.IsServerRouteSupported() {
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
|
||||
if e.config.BlockLANAccess {
|
||||
@@ -525,6 +524,11 @@ func (e *Engine) initFirewall() error {
|
||||
}
|
||||
|
||||
func (e *Engine) blockLanAccess() {
|
||||
if e.config.BlockInbound {
|
||||
// no need to set up extra deny rules if inbound is already blocked in general
|
||||
return
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// TODO: keep this updated
|
||||
@@ -796,56 +800,58 @@ func isNil(server nbssh.Server) bool {
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Warnf("running SSH server is not permitted")
|
||||
log.Info("SSH server is not enabled")
|
||||
return nil
|
||||
} else {
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.sshServer = nil
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.sshServer = nil
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
@@ -988,12 +994,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
|
||||
if e.acl != nil {
|
||||
@@ -1055,15 +1070,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
|
||||
e.connMgr.SetExcludeList(excludedLazyPeers)
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
@@ -1098,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
Network: prefix.Masked(),
|
||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||
NetID: route.NetID(protoRoute.NetID),
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
@@ -1132,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
return entries
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
@@ -1159,7 +1165,7 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.C
|
||||
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
|
||||
dnsNSGroup := &nbdns.NameServerGroup{
|
||||
Primary: nsGroup.GetPrimary(),
|
||||
Domains: nsGroup.GetDomains(),
|
||||
Domains: domain.FromPunycodeList(nsGroup.GetDomains()),
|
||||
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
|
||||
}
|
||||
for _, ns := range nsGroup.GetNameServers() {
|
||||
@@ -1447,6 +1453,7 @@ func (e *Engine) close() {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
}
|
||||
e.wgInterface = nil
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
@@ -1671,7 +1678,7 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||
return append(
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)...,
|
||||
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1784,9 +1791,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
}
|
||||
|
||||
// GetWgAddr returns the wireguard address
|
||||
func (e *Engine) GetWgAddr() net.IP {
|
||||
func (e *Engine) GetWgAddr() netip.Addr {
|
||||
if e.wgInterface == nil {
|
||||
return nil
|
||||
return netip.Addr{}
|
||||
}
|
||||
return e.wgInterface.Address().IP
|
||||
}
|
||||
@@ -1796,6 +1803,10 @@ func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
return
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1851,12 +1862,7 @@ func (e *Engine) Address() (netip.Addr, error) {
|
||||
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
addr := e.wgInterface.Address()
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
|
||||
}
|
||||
return ip.Unmap(), nil
|
||||
return e.wgInterface.Address().IP, nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||
|
||||
@@ -44,6 +44,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
mgmt "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -77,7 +78,7 @@ var (
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains domain.List) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
@@ -99,6 +100,10 @@ type MockWGIface struct {
|
||||
GetNetFunc func() *netstack.Net
|
||||
}
|
||||
|
||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
@@ -107,7 +112,7 @@ func (m *MockWGIface) Create() error {
|
||||
return m.CreateFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains domain.List) error {
|
||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||
}
|
||||
|
||||
@@ -371,11 +376,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
|
||||
@@ -14,11 +14,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type wgIfaceBase interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains domain.List) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
@@ -37,4 +38,5 @@ type wgIfaceBase interface {
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetNet() *netstack.Net
|
||||
FullStats() (*configurer.Stats, error)
|
||||
}
|
||||
|
||||
@@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
|
||||
|
||||
// fallback if mark rules are not in place
|
||||
wgnet := c.iface.Address().Network
|
||||
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
|
||||
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
|
||||
}
|
||||
|
||||
// mapRxPackets maps packet counts to RX based on flow direction
|
||||
@@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
|
||||
// fallback if marks are not set
|
||||
wgaddr := c.iface.Address().IP
|
||||
wgnetwork := c.iface.Address().Network
|
||||
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||
|
||||
switch {
|
||||
case wgaddr.Equal(src):
|
||||
case wgaddr == srcIP:
|
||||
return nftypes.Egress
|
||||
case wgaddr.Equal(dst):
|
||||
case wgaddr == dstIP:
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(src):
|
||||
case wgnetwork.Contains(srcIP):
|
||||
// netbird network -> resource network
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(dst):
|
||||
case wgnetwork.Contains(dstIP):
|
||||
// resource network -> netbird network
|
||||
return nftypes.Egress
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -23,17 +23,16 @@ type Logger struct {
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceIPNet net.IPNet
|
||||
wgIfaceNet netip.Prefix
|
||||
dnsCollection atomic.Bool
|
||||
exitNodeCollection atomic.Bool
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
|
||||
return &Logger{
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceIPNet: wgIfaceIPNet,
|
||||
wgIfaceNet: wgIfaceIPNet,
|
||||
Store: store.NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
@@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
|
||||
var isSrcExitNode bool
|
||||
var isDestExitNode bool
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||
if !l.wgIfaceNet.Contains(event.SourceIP) {
|
||||
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||
if !l.wgIfaceNet.Contains(event.DestIP) {
|
||||
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(nil, net.IPNet{})
|
||||
logger := logger.New(nil, netip.Prefix{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -34,11 +34,11 @@ type Manager struct {
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var ipNet net.IPNet
|
||||
var prefix netip.Prefix
|
||||
if iface != nil {
|
||||
ipNet = *iface.Address().Network
|
||||
prefix = iface.Address().Network
|
||||
}
|
||||
flowLogger := logger.New(statusRecorder, ipNet)
|
||||
flowLogger := logger.New(statusRecorder, prefix)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package netflow
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||
func TestManager_Update(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
@@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
|
||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
@@ -31,6 +33,10 @@ type ResolvedDomainInfo struct {
|
||||
ParentDomain domain.Domain
|
||||
}
|
||||
|
||||
type WGIfaceStatus interface {
|
||||
FullStats() (*configurer.Stats, error)
|
||||
}
|
||||
|
||||
type EventListener interface {
|
||||
OnEvent(event *proto.SystemEvent)
|
||||
}
|
||||
@@ -146,11 +152,31 @@ type FullStatus struct {
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
type StatusChangeSubscription struct {
|
||||
peerID string
|
||||
id string
|
||||
eventsChan chan struct{}
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func newStatusChangeSubscription(ctx context.Context, peerID string) *StatusChangeSubscription {
|
||||
return &StatusChangeSubscription{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
id: uuid.New().String(),
|
||||
eventsChan: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatusChangeSubscription) Events() chan struct{} {
|
||||
return s.eventsChan
|
||||
}
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
@@ -181,13 +207,14 @@ type Status struct {
|
||||
ingressGwMgr *ingressgw.Manager
|
||||
|
||||
routeIDLookup routeIDLookup
|
||||
wgIface WGIfaceStatus
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
offlinePeers: make([]State, 0),
|
||||
@@ -289,11 +316,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
if receivedState.IP != "" {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
oldState := peerState.ConnStatus
|
||||
|
||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||
peerState.ConnStatus = receivedState.ConnStatus
|
||||
@@ -309,11 +332,14 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if skipNotification {
|
||||
return nil
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
// when we close the connection we will not notify the router manager
|
||||
if receivedState.ConnStatus == StatusIdle {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -380,11 +406,8 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
if receivedState.IP != "" {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
oldState := peerState.ConnStatus
|
||||
oldIsRelayed := peerState.Relayed
|
||||
|
||||
peerState.ConnStatus = receivedState.ConnStatus
|
||||
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
||||
@@ -397,12 +420,13 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if skipNotification {
|
||||
return nil
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -415,7 +439,8 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
oldState := peerState.ConnStatus
|
||||
oldIsRelayed := peerState.Relayed
|
||||
|
||||
peerState.ConnStatus = receivedState.ConnStatus
|
||||
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
||||
@@ -425,12 +450,13 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if skipNotification {
|
||||
return nil
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -443,7 +469,8 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
oldState := peerState.ConnStatus
|
||||
oldIsRelayed := peerState.Relayed
|
||||
|
||||
peerState.ConnStatus = receivedState.ConnStatus
|
||||
peerState.Relayed = receivedState.Relayed
|
||||
@@ -452,12 +479,13 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if skipNotification {
|
||||
return nil
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -470,7 +498,8 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
oldState := peerState.ConnStatus
|
||||
oldIsRelayed := peerState.Relayed
|
||||
|
||||
peerState.ConnStatus = receivedState.ConnStatus
|
||||
peerState.Relayed = receivedState.Relayed
|
||||
@@ -482,12 +511,13 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if skipNotification {
|
||||
return nil
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -510,17 +540,12 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGSt
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
|
||||
switch {
|
||||
case receivedConnStatus == StatusConnecting:
|
||||
return true
|
||||
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting:
|
||||
return true
|
||||
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
|
||||
return curr.IP != ""
|
||||
default:
|
||||
return false
|
||||
}
|
||||
func hasStatusOrRelayedChange(oldConnStatus, newConnStatus ConnStatus, oldRelayed, newRelayed bool) bool {
|
||||
return oldRelayed != newRelayed || hasConnStatusChanged(newConnStatus, oldConnStatus)
|
||||
}
|
||||
|
||||
func hasConnStatusChanged(oldStatus, newStatus ConnStatus) bool {
|
||||
return newStatus != oldStatus
|
||||
}
|
||||
|
||||
// UpdatePeerFQDN update peer's state fqdn only
|
||||
@@ -553,19 +578,41 @@ func (d *Status) FinishPeerListModifications() {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
ch, found := d.changeNotify[peer]
|
||||
if found {
|
||||
return ch
|
||||
sub := newStatusChangeSubscription(ctx, peerID)
|
||||
if _, ok := d.changeNotify[peerID]; !ok {
|
||||
d.changeNotify[peerID] = make(map[string]*StatusChangeSubscription)
|
||||
}
|
||||
d.changeNotify[peerID][sub.id] = sub
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscription) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if subscription == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
return ch
|
||||
channels, ok := d.changeNotify[subscription.peerID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sub, exists := channels[subscription.id]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
delete(channels, subscription.id)
|
||||
if len(channels) == 0 {
|
||||
delete(d.changeNotify, sub.peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetLocalPeerState returns the local peer state
|
||||
@@ -940,13 +987,20 @@ func (d *Status) onConnectionChanged() {
|
||||
|
||||
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
ch, found := d.changeNotify[peerID]
|
||||
if !found {
|
||||
subs, ok := d.changeNotify[peerID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(d.changeNotify, peerID)
|
||||
for _, sub := range subs {
|
||||
// block the write because we do not want to miss notification
|
||||
// must have to be sure we will run the GetPeerState() on separated thread
|
||||
go func() {
|
||||
select {
|
||||
case sub.eventsChan <- struct{}{}:
|
||||
case <-sub.ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
@@ -1030,6 +1084,23 @@ func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||
return d.eventQueue.GetAll()
|
||||
}
|
||||
|
||||
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.wgIface = wgInterface
|
||||
}
|
||||
|
||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
if d.wgIface == nil {
|
||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||
}
|
||||
|
||||
return d.wgIface.FullStats()
|
||||
}
|
||||
|
||||
type EventQueue struct {
|
||||
maxSize int
|
||||
events []*proto.SystemEvent
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -42,16 +44,16 @@ func TestGetPeer(t *testing.T) {
|
||||
func TestUpdatePeerState(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "10.10.10.10"
|
||||
fqdn := "peer-a.netbird.local"
|
||||
status := NewRecorder("https://mgm")
|
||||
_ = status.AddPeer(key, fqdn, ip)
|
||||
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
PubKey: key,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: StatusConnecting,
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerState(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
@@ -83,25 +85,27 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "10.10.10.10"
|
||||
status := NewRecorder("https://mgm")
|
||||
_ = status.AddPeer(key, "abc.netbird", ip)
|
||||
|
||||
sub := status.SubscribeToPeerStateChanges(context.Background(), key)
|
||||
assert.NotNil(t, sub, "channel shouldn't be nil")
|
||||
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
PubKey: key,
|
||||
ConnStatus: StatusConnecting,
|
||||
Relayed: false,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
ch := status.GetPeerStateChangeNotifier(key)
|
||||
assert.NotNil(t, ch, "channel shouldn't be nil")
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
t.Errorf("channel wasn't closed after update")
|
||||
case <-sub.eventsChan:
|
||||
case <-timeoutCtx.Done():
|
||||
t.Errorf("timed out waiting for event")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ func ProbeAll(
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, uri := range relays {
|
||||
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package routemanager
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,10 +7,8 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
@@ -36,6 +34,7 @@ const (
|
||||
reasonRouteUpdate
|
||||
reasonPeerUpdate
|
||||
reasonShutdown
|
||||
reasonHA
|
||||
)
|
||||
|
||||
type routerPeerStatus struct {
|
||||
@@ -44,9 +43,9 @@ type routerPeerStatus struct {
|
||||
latency time.Duration
|
||||
}
|
||||
|
||||
type routesUpdate struct {
|
||||
updateSerial uint64
|
||||
routes []*route.Route
|
||||
type RoutesUpdate struct {
|
||||
UpdateSerial uint64
|
||||
Routes []*route.Route
|
||||
}
|
||||
|
||||
// RouteHandler defines the interface for handling routes
|
||||
@@ -58,64 +57,54 @@ type RouteHandler interface {
|
||||
RemoveAllowedIPs() error
|
||||
}
|
||||
|
||||
type clientNetwork struct {
|
||||
type WatcherConfig struct {
|
||||
Context context.Context
|
||||
DNSRouteInterval time.Duration
|
||||
WGInterface iface.WGIface
|
||||
StatusRecorder *peer.Status
|
||||
Route *route.Route
|
||||
Handler RouteHandler
|
||||
}
|
||||
|
||||
// Watcher watches route and peer changes and updates allowed IPs accordingly.
|
||||
// Once stopped, it cannot be reused.
|
||||
type Watcher struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface iface.WGIface
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
routeUpdate chan RoutesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
routePeersNotifiers map[string]chan struct{}
|
||||
routePeersNotifiers map[string]chan struct{} // map of peer key to channel for peer state changes
|
||||
currentChosen *route.Route
|
||||
handler RouteHandler
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(
|
||||
ctx context.Context,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface iface.WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
useNewDNSRoute bool,
|
||||
) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
func NewWatcher(config WatcherConfig) *Watcher {
|
||||
ctx, cancel := context.WithCancel(config.Context)
|
||||
|
||||
client := &clientNetwork{
|
||||
client := &Watcher{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: config.StatusRecorder,
|
||||
wgInterface: config.WGInterface,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
routeUpdate: make(chan RoutesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
handler: handlerFromRoute(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
dnsRouteInterval,
|
||||
statusRecorder,
|
||||
wgInterface,
|
||||
dnsServer,
|
||||
peerStore,
|
||||
useNewDNSRoute,
|
||||
),
|
||||
handler: config.Handler,
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
routePeerStatuses := make(map[route.ID]routerPeerStatus)
|
||||
for _, r := range c.routes {
|
||||
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
||||
for _, r := range w.routes {
|
||||
peerStatus, err := w.statusRecorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
log.Debugf("couldn't fetch peer state: %v", err)
|
||||
log.Debugf("couldn't fetch peer state %v: %v", r.Peer, err)
|
||||
continue
|
||||
}
|
||||
routePeerStatuses[r.ID] = routerPeerStatus{
|
||||
@@ -128,7 +117,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
}
|
||||
|
||||
// getBestRouteFromStatuses determines the most optimal route from the available routes
|
||||
// within a clientNetwork, taking into account peer connection status, route metrics, and
|
||||
// within a Watcher, taking into account peer connection status, route metrics, and
|
||||
// preference for non-relayed and direct connections.
|
||||
//
|
||||
// It follows these prioritization rules:
|
||||
@@ -140,17 +129,17 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
//
|
||||
// It returns the ID of the selected optimal route.
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||
chosen := route.ID("")
|
||||
func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||
var chosen route.ID
|
||||
chosenScore := float64(0)
|
||||
currScore := float64(0)
|
||||
|
||||
currID := route.ID("")
|
||||
if c.currentChosen != nil {
|
||||
currID = c.currentChosen.ID
|
||||
var currID route.ID
|
||||
if w.currentChosen != nil {
|
||||
currID = w.currentChosen.ID
|
||||
}
|
||||
|
||||
for _, r := range c.routes {
|
||||
for _, r := range w.routes {
|
||||
tempScore := float64(0)
|
||||
peerStatus, found := routePeerStatuses[r.ID]
|
||||
if !found || !peerStatus.connected {
|
||||
@@ -167,7 +156,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
if peerStatus.latency != 0 {
|
||||
latency = peerStatus.latency
|
||||
} else {
|
||||
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
||||
log.Tracef("peer %s has 0 latency, range %s", r.Peer, w.handler)
|
||||
}
|
||||
|
||||
// avoid negative tempScore on the higher latency calculation
|
||||
@@ -197,149 +186,145 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
||||
chosenID := chosen
|
||||
if chosen == "" {
|
||||
chosenID = "<none>"
|
||||
}
|
||||
currentID := currID
|
||||
if currID == "" {
|
||||
currentID = "<none>"
|
||||
}
|
||||
|
||||
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosenID, chosenScore, currentID, currScore)
|
||||
|
||||
switch {
|
||||
case chosen == "":
|
||||
var peers []string
|
||||
for _, r := range c.routes {
|
||||
for _, r := range w.routes {
|
||||
peers = append(peers, r.Peer)
|
||||
}
|
||||
|
||||
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
|
||||
log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", w.handler, peers)
|
||||
case chosen != currID:
|
||||
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
||||
if currScore != 0 && currScore+0.01 > chosenScore {
|
||||
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
log.Debugf("keeping current routing peer %s for [%v]: the score difference with latency is less than 0.01(10ms): current: %f, new: %f",
|
||||
w.currentChosen.Peer, w.handler, currScore, chosenScore)
|
||||
return currID
|
||||
}
|
||||
var p string
|
||||
if rt := c.routes[chosen]; rt != nil {
|
||||
if rt := w.routes[chosen]; rt != nil {
|
||||
p = rt.Peer
|
||||
}
|
||||
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
|
||||
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, w.handler)
|
||||
}
|
||||
|
||||
return chosen
|
||||
}
|
||||
|
||||
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
|
||||
func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
|
||||
subscription := w.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey)
|
||||
defer w.statusRecorder.UnsubscribePeerStateChanges(subscription)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-closer:
|
||||
return
|
||||
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil || state.ConnStatus == peer.StatusConnecting {
|
||||
continue
|
||||
}
|
||||
case <-subscription.Events():
|
||||
peerStateUpdate <- struct{}{}
|
||||
log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus)
|
||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
for _, r := range c.routes {
|
||||
_, found := c.routePeersNotifiers[r.Peer]
|
||||
if found {
|
||||
func (w *Watcher) startNewPeerStatusWatchers() {
|
||||
for _, r := range w.routes {
|
||||
if _, found := w.routePeersNotifiers[r.Peer]; found {
|
||||
continue
|
||||
}
|
||||
|
||||
closerChan := make(chan struct{})
|
||||
c.routePeersNotifiers[r.Peer] = closerChan
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
||||
w.routePeersNotifiers[r.Peer] = closerChan
|
||||
go w.watchPeerStatusChanges(w.ctx, r.Peer, w.peerStateUpdate, closerChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
||||
// addAllowedIPs adds the allowed IPs for the current chosen route to the handler.
|
||||
func (w *Watcher) addAllowedIPs(route *route.Route) error {
|
||||
if err := w.handler.AddAllowedIPs(route.Peer); err != nil {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", route.Peer, err)
|
||||
}
|
||||
|
||||
if err := w.statusRecorder.AddPeerStateRoute(route.Peer, w.handler.String(), route.GetResourceID()); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
}
|
||||
w.connectEvent(route)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error {
|
||||
if c.currentChosen == nil {
|
||||
return nil
|
||||
func (w *Watcher) removeAllowedIPs(route *route.Route, rsn reason) error {
|
||||
if err := w.statusRecorder.RemovePeerStateRoute(route.Peer, w.handler.String()); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
||||
if err := w.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
}
|
||||
|
||||
c.disconnectEvent(rsn)
|
||||
w.disconnectEvent(route, rsn)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error {
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
func (w *Watcher) recalculateRoutes(rsn reason) error {
|
||||
routerPeerStatuses := w.getRouterPeerStatuses()
|
||||
|
||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
newChosenID := w.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
|
||||
// If no route is chosen, remove the route from the peer and system
|
||||
// If no route is chosen, remove the route from the peer
|
||||
if newChosenID == "" {
|
||||
if err := c.removeRouteFromPeerAndSystem(rsn); err != nil {
|
||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
||||
if w.currentChosen == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.currentChosen = nil
|
||||
if err := w.removeAllowedIPs(w.currentChosen, rsn); err != nil {
|
||||
return fmt.Errorf("remove obsolete: %w", err)
|
||||
}
|
||||
|
||||
w.currentChosen = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the chosen route is the same as the current route, do nothing
|
||||
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
||||
c.currentChosen.Equal(c.routes[newChosenID]) {
|
||||
if w.currentChosen != nil && w.currentChosen.ID == newChosenID &&
|
||||
w.currentChosen.Equal(w.routes[newChosenID]) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var isNew bool
|
||||
if c.currentChosen == nil {
|
||||
// If they were not previously assigned to another peer, add routes to the system first
|
||||
if err := c.handler.AddRoute(c.ctx); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
isNew = true
|
||||
} else {
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
// If the chosen route was assigned to a different peer, remove the allowed IPs first
|
||||
if isNew := w.currentChosen == nil; !isNew {
|
||||
if err := w.removeAllowedIPs(w.currentChosen, reasonHA); err != nil {
|
||||
return fmt.Errorf("remove old: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.currentChosen = c.routes[newChosenID]
|
||||
|
||||
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
newChosenRoute := w.routes[newChosenID]
|
||||
if err := w.addAllowedIPs(newChosenRoute); err != nil {
|
||||
return fmt.Errorf("add new: %w", err)
|
||||
}
|
||||
|
||||
if isNew {
|
||||
c.connectEvent()
|
||||
}
|
||||
w.currentChosen = newChosenRoute
|
||||
|
||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add peer state route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) connectEvent() {
|
||||
func (w *Watcher) connectEvent(route *route.Route) {
|
||||
var defaultRoute bool
|
||||
for _, r := range c.routes {
|
||||
for _, r := range w.routes {
|
||||
if r.Network.Bits() == 0 {
|
||||
defaultRoute = true
|
||||
break
|
||||
@@ -351,13 +336,13 @@ func (c *clientNetwork) connectEvent() {
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
"network": c.handler.String(),
|
||||
"network": w.handler.String(),
|
||||
}
|
||||
if c.currentChosen != nil {
|
||||
meta["id"] = string(c.currentChosen.NetID)
|
||||
meta["peer"] = c.currentChosen.Peer
|
||||
if route != nil {
|
||||
meta["id"] = string(route.NetID)
|
||||
meta["peer"] = route.Peer
|
||||
}
|
||||
c.statusRecorder.PublishEvent(
|
||||
w.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_NETWORK,
|
||||
"Default route added",
|
||||
@@ -366,9 +351,9 @@ func (c *clientNetwork) connectEvent() {
|
||||
)
|
||||
}
|
||||
|
||||
func (c *clientNetwork) disconnectEvent(rsn reason) {
|
||||
func (w *Watcher) disconnectEvent(route *route.Route, rsn reason) {
|
||||
var defaultRoute bool
|
||||
for _, r := range c.routes {
|
||||
for _, r := range w.routes {
|
||||
if r.Network.Bits() == 0 {
|
||||
defaultRoute = true
|
||||
break
|
||||
@@ -384,11 +369,11 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
|
||||
var userMessage string
|
||||
meta := make(map[string]string)
|
||||
|
||||
if c.currentChosen != nil {
|
||||
meta["id"] = string(c.currentChosen.NetID)
|
||||
meta["peer"] = c.currentChosen.Peer
|
||||
if route != nil {
|
||||
meta["id"] = string(route.NetID)
|
||||
meta["peer"] = route.Peer
|
||||
}
|
||||
meta["network"] = c.handler.String()
|
||||
meta["network"] = w.handler.String()
|
||||
switch rsn {
|
||||
case reasonShutdown:
|
||||
severity = proto.SystemEvent_INFO
|
||||
@@ -401,13 +386,17 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
|
||||
severity = proto.SystemEvent_WARNING
|
||||
message = "Default route disconnected due to peer unreachability"
|
||||
userMessage = "Exit node connection lost. Your internet access might be affected."
|
||||
case reasonHA:
|
||||
severity = proto.SystemEvent_INFO
|
||||
message = "Default route disconnected due to high availability change"
|
||||
userMessage = "Exit node disconnected due to high availability change."
|
||||
default:
|
||||
severity = proto.SystemEvent_ERROR
|
||||
message = "Default route disconnected for unknown reasons"
|
||||
userMessage = "Exit node disconnected for unknown reasons."
|
||||
}
|
||||
|
||||
c.statusRecorder.PublishEvent(
|
||||
w.statusRecorder.PublishEvent(
|
||||
severity,
|
||||
proto.SystemEvent_NETWORK,
|
||||
message,
|
||||
@@ -416,86 +405,101 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
|
||||
)
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
func (w *Watcher) SendUpdate(update RoutesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
select {
|
||||
case w.routeUpdate <- update:
|
||||
case <-w.ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
|
||||
func (w *Watcher) classifyUpdate(update RoutesUpdate) bool {
|
||||
isUpdateMapDifferent := false
|
||||
updateMap := make(map[route.ID]*route.Route)
|
||||
|
||||
for _, r := range update.routes {
|
||||
for _, r := range update.Routes {
|
||||
updateMap[r.ID] = r
|
||||
}
|
||||
|
||||
if len(c.routes) != len(updateMap) {
|
||||
if len(w.routes) != len(updateMap) {
|
||||
isUpdateMapDifferent = true
|
||||
}
|
||||
|
||||
for id, r := range c.routes {
|
||||
for id, r := range w.routes {
|
||||
_, found := updateMap[id]
|
||||
if !found {
|
||||
close(c.routePeersNotifiers[r.Peer])
|
||||
delete(c.routePeersNotifiers, r.Peer)
|
||||
close(w.routePeersNotifiers[r.Peer])
|
||||
delete(w.routePeersNotifiers, r.Peer)
|
||||
isUpdateMapDifferent = true
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
|
||||
if !reflect.DeepEqual(w.routes[id], updateMap[id]) {
|
||||
isUpdateMapDifferent = true
|
||||
}
|
||||
}
|
||||
|
||||
c.routes = updateMap
|
||||
w.routes = updateMap
|
||||
return isUpdateMapDifferent
|
||||
}
|
||||
|
||||
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
|
||||
// Start is the main point of reacting on client network routing events.
|
||||
// All the processing related to the client network should be done here. Thread-safe.
|
||||
func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
func (w *Watcher) Start() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
||||
if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil {
|
||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
||||
}
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
case <-w.peerStateUpdate:
|
||||
if err := w.recalculateRoutes(reasonPeerUpdate); err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", w.handler, err)
|
||||
}
|
||||
case update := <-c.routeUpdate:
|
||||
if update.updateSerial < c.updateSerial {
|
||||
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
|
||||
case update := <-w.routeUpdate:
|
||||
if update.UpdateSerial < w.updateSerial {
|
||||
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", w.updateSerial, update.UpdateSerial)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Received a new client network route update for [%v]", c.handler)
|
||||
|
||||
// hash update somehow
|
||||
isTrueRouteUpdate := c.handleUpdate(update)
|
||||
|
||||
c.updateSerial = update.updateSerial
|
||||
|
||||
if isTrueRouteUpdate {
|
||||
log.Debug("Client network update contains different routes, recalculating routes")
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
} else {
|
||||
log.Debug("Route update is not different, skipping route recalculation")
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
w.handleRouteUpdate(update)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(
|
||||
func (w *Watcher) handleRouteUpdate(update RoutesUpdate) {
|
||||
log.Debugf("Received a new client network route update for [%v]", w.handler)
|
||||
|
||||
// hash update somehow
|
||||
isTrueRouteUpdate := w.classifyUpdate(update)
|
||||
|
||||
w.updateSerial = update.UpdateSerial
|
||||
|
||||
if isTrueRouteUpdate {
|
||||
log.Debugf("client network update %v for [%v] contains different routes, recalculating routes", update.UpdateSerial, w.handler)
|
||||
if err := w.recalculateRoutes(reasonRouteUpdate); err != nil {
|
||||
log.Errorf("failed to recalculate routes for network [%v]: %v", w.handler, err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("route update %v for [%v] is not different, skipping route recalculation", update.UpdateSerial, w.handler)
|
||||
}
|
||||
|
||||
w.startNewPeerStatusWatchers()
|
||||
}
|
||||
|
||||
// Stop stops the watcher and cleans up resources.
|
||||
func (w *Watcher) Stop() {
|
||||
log.Debugf("Stopping watcher for network [%v]", w.handler)
|
||||
|
||||
w.cancel()
|
||||
|
||||
if w.currentChosen == nil {
|
||||
return
|
||||
}
|
||||
if err := w.removeAllowedIPs(w.currentChosen, reasonShutdown); err != nil {
|
||||
log.Errorf("Failed to remove routes for [%v]: %v", w.handler, err)
|
||||
}
|
||||
}
|
||||
|
||||
func HandlerFromRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
@@ -1,4 +1,4 @@
|
||||
package routemanager
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -395,7 +395,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
}
|
||||
|
||||
// create new clientNetwork
|
||||
client := &clientNetwork{
|
||||
client := &Watcher{
|
||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||
routes: tc.existingRoutes,
|
||||
currentChosen: currentRoute,
|
||||
@@ -229,15 +229,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
}
|
||||
|
||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||
origPattern := ""
|
||||
var origPattern domain.Domain
|
||||
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
||||
origPattern = writer.GetOrigPattern()
|
||||
}
|
||||
|
||||
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
|
||||
|
||||
// already punycode via RegisterHandler()
|
||||
originalDomain := domain.Domain(origPattern)
|
||||
originalDomain := origPattern
|
||||
if originalDomain == "" {
|
||||
originalDomain = resolvedDomain
|
||||
}
|
||||
@@ -264,7 +263,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip, ip.BitLen())
|
||||
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||
newPrefixes = append(newPrefixes, prefix)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
@@ -21,9 +23,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/server"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
@@ -68,9 +72,9 @@ type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
clientNetworks map[route.HAUniqueID]*client.Watcher
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter *serverRouter
|
||||
serverRouter *server.Router
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
relayMgr *relayClient.Manager
|
||||
@@ -88,6 +92,7 @@ type DefaultManager struct {
|
||||
useNewDNSRoute bool
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
@@ -99,7 +104,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
dnsRouteInterval: config.DNSRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
clientNetworks: make(map[route.HAUniqueID]*client.Watcher),
|
||||
relayMgr: config.RelayManager,
|
||||
sysOps: sysOps,
|
||||
statusRecorder: config.StatusRecorder,
|
||||
@@ -111,6 +116,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
peerStore: config.PeerStore,
|
||||
disableClientRoutes: config.DisableClientRoutes,
|
||||
disableServerRoutes: config.DisableServerRoutes,
|
||||
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
|
||||
}
|
||||
|
||||
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
||||
@@ -226,7 +232,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
}
|
||||
|
||||
var err error
|
||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -237,7 +243,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
m.stop()
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.cleanUp()
|
||||
m.serverRouter.CleanUp()
|
||||
}
|
||||
|
||||
if m.routeRefCounter != nil {
|
||||
@@ -265,6 +271,54 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||
func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
toAdd := make(map[route.HAUniqueID]*route.Route)
|
||||
toRemove := make(map[route.HAUniqueID]client.RouteHandler)
|
||||
|
||||
for id, routes := range newRoutes {
|
||||
if len(routes) > 0 {
|
||||
toAdd[id] = routes[0]
|
||||
}
|
||||
}
|
||||
|
||||
for id, activeHandler := range m.activeRoutes {
|
||||
if _, exists := toAdd[id]; exists {
|
||||
delete(toAdd, id)
|
||||
} else {
|
||||
toRemove[id] = activeHandler
|
||||
}
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for id, handler := range toRemove {
|
||||
if err := handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
||||
}
|
||||
delete(m.activeRoutes, id)
|
||||
}
|
||||
|
||||
for id, route := range toAdd {
|
||||
handler := client.HandlerFromRoute(
|
||||
route,
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsRouteInterval,
|
||||
m.statusRecorder,
|
||||
m.wgInterface,
|
||||
m.dnsServer,
|
||||
m.peerStore,
|
||||
m.useNewDNSRoute,
|
||||
)
|
||||
if err := handler.AddRoute(m.ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||
continue
|
||||
}
|
||||
m.activeRoutes[id] = handler
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
@@ -279,22 +333,28 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
var merr *multierror.Error
|
||||
if !m.disableClientRoutes {
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
|
||||
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
|
||||
}
|
||||
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
}
|
||||
m.clientRoutes = newClientRoutesIDMap
|
||||
|
||||
if m.serverRouter == nil {
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
|
||||
return fmt.Errorf("update routes: %w", err)
|
||||
if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||
@@ -341,6 +401,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
|
||||
m.notifier.OnNewRoutes(networks)
|
||||
|
||||
if err := m.updateSystemRoutes(networks); err != nil {
|
||||
log.Errorf("failed to update system routes during selection: %v", err)
|
||||
}
|
||||
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
for id, routes := range networks {
|
||||
@@ -349,21 +413,24 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(
|
||||
m.ctx,
|
||||
m.dnsRouteInterval,
|
||||
m.wgInterface,
|
||||
m.statusRecorder,
|
||||
routes[0],
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsServer,
|
||||
m.peerStore,
|
||||
m.useNewDNSRoute,
|
||||
)
|
||||
handler := m.activeRoutes[id]
|
||||
if handler == nil {
|
||||
log.Warnf("no active handler found for route %s", id)
|
||||
continue
|
||||
}
|
||||
|
||||
config := client.WatcherConfig{
|
||||
Context: m.ctx,
|
||||
DNSRouteInterval: m.dnsRouteInterval,
|
||||
WGInterface: m.wgInterface,
|
||||
StatusRecorder: m.statusRecorder,
|
||||
Route: routes[0],
|
||||
Handler: handler,
|
||||
}
|
||||
clientNetworkWatcher := client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
go clientNetworkWatcher.Start()
|
||||
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
||||
@@ -375,8 +442,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||
for id, client := range m.clientNetworks {
|
||||
if _, ok := networks[id]; !ok {
|
||||
log.Debugf("Stopping client network watcher, %s", id)
|
||||
client.cancel()
|
||||
client.Stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
@@ -389,26 +455,29 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(
|
||||
m.ctx,
|
||||
m.dnsRouteInterval,
|
||||
m.wgInterface,
|
||||
m.statusRecorder,
|
||||
routes[0],
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsServer,
|
||||
m.peerStore,
|
||||
m.useNewDNSRoute,
|
||||
)
|
||||
handler := m.activeRoutes[id]
|
||||
if handler == nil {
|
||||
log.Errorf("No active handler found for route %s", id)
|
||||
continue
|
||||
}
|
||||
|
||||
config := client.WatcherConfig{
|
||||
Context: m.ctx,
|
||||
DNSRouteInterval: m.dnsRouteInterval,
|
||||
WGInterface: m.wgInterface,
|
||||
StatusRecorder: m.statusRecorder,
|
||||
Route: routes[0],
|
||||
Handler: handler,
|
||||
}
|
||||
clientNetworkWatcher = client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
go clientNetworkWatcher.Start()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
update := client.RoutesUpdate{
|
||||
UpdateSerial: updateSerial,
|
||||
Routes: routes,
|
||||
}
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
clientNetworkWatcher.SendUpdate(update)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
@@ -45,7 +44,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -72,7 +71,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.252.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.252.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -100,7 +99,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.30.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -128,7 +127,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.30.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -212,7 +211,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -234,7 +233,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -251,7 +250,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -273,7 +272,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -283,7 +282,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "b",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey2,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -300,7 +299,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -328,7 +327,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -357,7 +356,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "l1",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -377,7 +376,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "r1",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@@ -441,11 +440,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
}
|
||||
|
||||
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
@@ -454,8 +453,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
}
|
||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||
|
||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
if routeManager.serverRouter != nil {
|
||||
require.Equal(t, testCase.serverRoutesExpected, routeManager.serverRouter.RoutesCount(), "server networks size should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package routemanager
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
type Router struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[route.ID]*route.Route
|
||||
@@ -23,8 +23,8 @@ type serverRouter struct {
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
return &serverRouter{
|
||||
func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) {
|
||||
return &Router{
|
||||
ctx: ctx,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
firewall: firewall,
|
||||
@@ -33,104 +33,110 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
|
||||
serverRoutesToRemove := make([]route.ID, 0)
|
||||
|
||||
for routeID := range m.routes {
|
||||
for routeID := range r.routes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.Equal(m.routes[routeID]) {
|
||||
if !found || !update.Equal(r.routes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.routes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
oldRoute := r.routes[routeID]
|
||||
err := r.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.routes, routeID)
|
||||
delete(r.routes, routeID)
|
||||
}
|
||||
|
||||
// If routing is to be disabled, do it after routes have been removed
|
||||
// If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled
|
||||
if len(routesMap) > 0 {
|
||||
if err := m.firewall.EnableRouting(); err != nil {
|
||||
if err := r.firewall.EnableRouting(); err != nil {
|
||||
return fmt.Errorf("enable routing: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.firewall.DisableRouting(); err != nil {
|
||||
if err := r.firewall.DisableRouting(); err != nil {
|
||||
return fmt.Errorf("disable routing: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.routes[id]
|
||||
_, found := r.routes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute, useNewDNSRoute)
|
||||
err := r.addToServerNetwork(newRoute, useNewDNSRoute)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.routes[id] = newRoute
|
||||
r.routes[id] = newRoute
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
if m.ctx.Err() != nil {
|
||||
func (r *Router) removeFromServerNetwork(route *route.Route) error {
|
||||
if r.ctx.Err() != nil {
|
||||
log.Infof("Not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
return r.ctx.Err()
|
||||
}
|
||||
|
||||
routerPair := routeToRouterPair(route, false)
|
||||
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
|
||||
if err := r.firewall.RemoveNatRule(routerPair); err != nil {
|
||||
return fmt.Errorf("remove routing rules: %w", err)
|
||||
}
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
|
||||
delete(r.routes, route.ID)
|
||||
r.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
|
||||
if m.ctx.Err() != nil {
|
||||
func (r *Router) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
|
||||
if r.ctx.Err() != nil {
|
||||
log.Infof("Not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
return r.ctx.Err()
|
||||
}
|
||||
|
||||
routerPair := routeToRouterPair(route, useNewDNSRoute)
|
||||
if err := m.firewall.AddNatRule(routerPair); err != nil {
|
||||
if err := r.firewall.AddNatRule(routerPair); err != nil {
|
||||
return fmt.Errorf("insert routing rules: %w", err)
|
||||
}
|
||||
|
||||
m.routes[route.ID] = route
|
||||
m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
|
||||
r.routes[route.ID] = route
|
||||
r.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) cleanUp() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
func (r *Router) CleanUp() {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
|
||||
for _, r := range m.routes {
|
||||
routerPair := routeToRouterPair(r, false)
|
||||
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
|
||||
for _, route := range r.routes {
|
||||
routerPair := routeToRouterPair(route, false)
|
||||
if err := r.firewall.RemoveNatRule(routerPair); err != nil {
|
||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.statusRecorder.CleanLocalPeerStateRoutes()
|
||||
r.statusRecorder.CleanLocalPeerStateRoutes()
|
||||
}
|
||||
|
||||
func (r *Router) RoutesCount() int {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
return len(r.routes)
|
||||
}
|
||||
|
||||
func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {
|
||||
@@ -24,19 +24,22 @@ func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allo
|
||||
}
|
||||
}
|
||||
|
||||
// Route route methods
|
||||
func (r *Route) String() string {
|
||||
return r.route.Network.String()
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(context.Context) error {
|
||||
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
|
||||
return err
|
||||
if _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) RemoveRoute() error {
|
||||
_, err := r.routeRefCounter.Decrement(r.route.Network)
|
||||
return err
|
||||
if _, err := r.routeRefCounter.Decrement(r.route.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
@@ -52,6 +55,8 @@ func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
|
||||
return err
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(r.route.Network); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,8 +22,13 @@ const (
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
type iface interface {
|
||||
Address() wgaddr.Address
|
||||
Name() string
|
||||
}
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface iface.WGIface) (map[string]int, error) {
|
||||
func Setup(wgIface iface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
|
||||
@@ -6,9 +6,10 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
type Nexthop struct {
|
||||
@@ -30,11 +31,16 @@ func (n Nexthop) String() string {
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
|
||||
type wgIface interface {
|
||||
Address() wgaddr.Address
|
||||
Name() string
|
||||
}
|
||||
|
||||
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface iface.WGIface
|
||||
wgInterface wgIface
|
||||
// prefixes is tracking all the current added prefixes im memory
|
||||
// (this is used in iOS as all route updates require a full table update)
|
||||
//nolint
|
||||
@@ -45,9 +51,27 @@ type SysOps struct {
|
||||
notifier *notifier.Notifier
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
|
||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
|
||||
switch {
|
||||
case
|
||||
!addr.IsValid(),
|
||||
addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsMulticast(),
|
||||
addr.IsUnspecified() && prefix.Bits() != 0,
|
||||
r.wgInterface.Address().Network.Contains(addr):
|
||||
return vars.ErrRouteNotAllowed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -33,7 +35,12 @@ func init() {
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
|
||||
@@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@@ -119,18 +126,39 @@ func TestBits(t *testing.T) {
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return "lo0"
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
@@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) {
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
@@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(nexthop.IP) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
|
||||
if nexthop.IP.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
nexthop, err = GetNextHop(nexthop.IP)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
|
||||
return r.addToRouteTable(gatewayPrefix, nexthop)
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return Nexthop{}, err
|
||||
}
|
||||
|
||||
addr := prefix.Addr()
|
||||
if addr.IsUnspecified() {
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
@@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface
|
||||
Intf: nexthop.Intf,
|
||||
}
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
vpnAddr := vpnIntf.Address().IP
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||
@@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
|
||||
return r.addToRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
@@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return Nexthop{Intf: intf}, nil
|
||||
}
|
||||
|
||||
if preferredSrc == nil {
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
return Nexthop{Intf: intf}, nil
|
||||
}
|
||||
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
@@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
localRoutes, err := hasSeparateRouting()
|
||||
|
||||
@@ -3,23 +3,25 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
@@ -27,105 +29,370 @@ type dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func TestAddRemoveRoutes(t *testing.T) {
|
||||
func TestAddVPNRoute(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
shouldRouteToWireguard bool
|
||||
shouldBeRemoved bool
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove Route 100.66.120.0/24",
|
||||
prefix: netip.MustParsePrefix("100.66.120.0/24"),
|
||||
shouldRouteToWireguard: true,
|
||||
shouldBeRemoved: true,
|
||||
name: "IPv4 - Private network route",
|
||||
prefix: netip.MustParsePrefix("10.10.100.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Or Remove Route 127.0.0.1/32",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
shouldRouteToWireguard: false,
|
||||
shouldBeRemoved: false,
|
||||
name: "IPv4 Single host",
|
||||
prefix: netip.MustParsePrefix("10.111.111.111/32"),
|
||||
},
|
||||
{
|
||||
name: "IPv4 RFC3927 test range",
|
||||
prefix: netip.MustParsePrefix("198.51.100.0/24"),
|
||||
},
|
||||
{
|
||||
name: "IPv4 Default route",
|
||||
prefix: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
|
||||
{
|
||||
name: "IPv6 Subnet",
|
||||
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Single host",
|
||||
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Default route",
|
||||
prefix: netip.MustParsePrefix("::/0"),
|
||||
},
|
||||
|
||||
// IPv4 addresses that should be rejected (matches validateRoute logic)
|
||||
{
|
||||
name: "IPv4 Loopback",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("169.254.1.1/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Link-local multicast",
|
||||
prefix: netip.MustParsePrefix("224.0.0.251/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Multicast",
|
||||
prefix: netip.MustParsePrefix("239.255.255.250/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Unspecified with prefix",
|
||||
prefix: netip.MustParsePrefix("0.0.0.0/32"),
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// IPv6 addresses that should be rejected (matches validateRoute logic)
|
||||
{
|
||||
name: "IPv6 Loopback",
|
||||
prefix: netip.MustParsePrefix("::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("fe80::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Link-local multicast",
|
||||
prefix: netip.MustParsePrefix("ff02::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Interface-local multicast",
|
||||
prefix: netip.MustParsePrefix("ff01::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Multicast",
|
||||
prefix: netip.MustParsePrefix("ff00::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Unspecified with prefix",
|
||||
prefix: netip.MustParsePrefix("::/128"),
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "IPv4 WireGuard interface network overlap",
|
||||
prefix: netip.MustParsePrefix("100.65.75.0/24"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 WireGuard interface network subnet",
|
||||
prefix: netip.MustParsePrefix("100.65.75.0/32"),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping ", testCase.name, " on freebsd")
|
||||
}
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
||||
Address: "100.65.75.2/24",
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
|
||||
_, _, err = r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
// add the route
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||
if testCase.expectError {
|
||||
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||
// validate it's pointing to the WireGuard interface
|
||||
require.NoError(t, err)
|
||||
|
||||
nextHop := getNextHop(t, testCase.prefix.Addr())
|
||||
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
|
||||
|
||||
// remove route again
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// validate it's gone
|
||||
nextHop, err = GetNextHop(testCase.prefix.Addr())
|
||||
require.True(t,
|
||||
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
|
||||
"err: %v, next hop: %v", err, nextHop)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
|
||||
nextHop, err := GetNextHop(addr)
|
||||
|
||||
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
|
||||
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
|
||||
// present in the route table.
|
||||
t.Skip("Skipping windows test")
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
|
||||
|
||||
return nextHop
|
||||
}
|
||||
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
|
||||
|
||||
if addr.IsUnspecified() {
|
||||
// On macOS, querying 0.0.0.0 returns the wrong interface
|
||||
if addr.Is4() {
|
||||
addr = netip.MustParseAddr("1.2.3.4")
|
||||
} else {
|
||||
addr = netip.MustParseAddr("2001:db8::1")
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.Command("route", "-n", "get", addr.String())
|
||||
if addr.Is6() {
|
||||
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
|
||||
}
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
t.Logf("route output: %s", output)
|
||||
require.NoError(t, err, "%s failed")
|
||||
|
||||
lines := strings.Split(string(output), "\n")
|
||||
var intf string
|
||||
var gateway string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "interface:") {
|
||||
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
|
||||
} else if strings.HasPrefix(line, "gateway:") {
|
||||
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
|
||||
}
|
||||
}
|
||||
|
||||
require.NotEmpty(t, intf, "interface should be found in route output")
|
||||
|
||||
iface, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "interface %s should exist", intf)
|
||||
|
||||
nexthop := Nexthop{Intf: iface}
|
||||
|
||||
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
|
||||
addr, err := netip.ParseAddr(gateway)
|
||||
if err == nil {
|
||||
nexthop.IP = addr
|
||||
}
|
||||
}
|
||||
|
||||
return nexthop
|
||||
}
|
||||
|
||||
func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
expectError bool
|
||||
errorType error
|
||||
}{
|
||||
{
|
||||
name: "IPv4 RFC3927 test range",
|
||||
prefix: netip.MustParsePrefix("198.51.100.0/24"),
|
||||
},
|
||||
{
|
||||
name: "IPv4 Single host",
|
||||
prefix: netip.MustParsePrefix("8.8.8.8/32"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 External network route",
|
||||
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Single host",
|
||||
prefix: netip.MustParsePrefix("2001:db8::1/128"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Subnet",
|
||||
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Single host",
|
||||
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
|
||||
},
|
||||
|
||||
// Addresses that should be rejected
|
||||
{
|
||||
name: "IPv4 Loopback",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("169.254.1.1/32"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Multicast",
|
||||
prefix: netip.MustParsePrefix("239.255.255.250/32"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Unspecified",
|
||||
prefix: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Loopback",
|
||||
prefix: netip.MustParsePrefix("::1/128"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("fe80::1/128"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Multicast",
|
||||
prefix: netip.MustParsePrefix("ff00::1/128"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Unspecified",
|
||||
prefix: netip.MustParsePrefix("::/0"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 WireGuard interface network overlap",
|
||||
prefix: netip.MustParsePrefix("100.65.75.0/24"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
require.NoError(t, err, "Should be able to get IPv4 default route")
|
||||
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
|
||||
|
||||
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if testCase.prefix.Addr().Is6() &&
|
||||
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
|
||||
t.Skip("Skipping test as no ipv6 default route is available")
|
||||
}
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
t.Fatalf("Failed to get IPv6 default route: %v", err)
|
||||
}
|
||||
|
||||
var initialNextHop Nexthop
|
||||
if testCase.prefix.Addr().Is6() {
|
||||
initialNextHop = initialNextHopV6
|
||||
} else {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
||||
initialNextHop = initialNextHopV4
|
||||
}
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||
|
||||
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
|
||||
|
||||
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
if testCase.expectError {
|
||||
require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
|
||||
|
||||
// Verify the route was added and points to non-VPN interface
|
||||
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
|
||||
|
||||
err = r.removeFromRouteTable(testCase.prefix, nexthop)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNextHop(t *testing.T) {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping on freebsd")
|
||||
}
|
||||
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
if !nexthop.IP.IsValid() {
|
||||
if !defaultNh.IP.IsValid() {
|
||||
t.Fatal("should return a gateway")
|
||||
}
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
@@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var testingIP string
|
||||
var testingPrefix netip.Prefix
|
||||
for _, address := range addresses {
|
||||
if address.Network() != "ip+net" {
|
||||
@@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) {
|
||||
}
|
||||
prefix := netip.MustParsePrefix(address.String())
|
||||
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
|
||||
testingIP = prefix.Addr().String()
|
||||
testingPrefix = prefix.Masked()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
localIP, err := GetNextHop(testingPrefix.Addr())
|
||||
nh, err := GetNextHop(testingPrefix.Addr())
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error: ", err)
|
||||
}
|
||||
if !localIP.IP.IsValid() {
|
||||
if nh.Intf == nil {
|
||||
t.Fatal("should return a gateway for local network")
|
||||
}
|
||||
if localIP.IP.String() == nexthop.IP.String() {
|
||||
t.Fatal("local IP should not match with gateway IP")
|
||||
if nh.IP.String() == defaultNh.IP.String() {
|
||||
t.Fatal("next hop IP should not match with default gateway IP")
|
||||
}
|
||||
if localIP.IP.String() != testingIP {
|
||||
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultNexthop: ", defaultNexthop)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
preExistingPrefix netip.Prefix
|
||||
shouldAddRoute bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove random Route",
|
||||
prefix: netip.MustParsePrefix("99.99.99.99/32"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway",
|
||||
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if bigger network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if smaller network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if same network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
||||
Address: "100.65.75.2/24",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
// test if route exists after adding
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "should not return err")
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
// route should either not have been added or should have been removed
|
||||
// In case of already existing route, it should not have been added (but still exist)
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
t.Log("Buffer string: ", buf.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
if !strings.Contains(buf.String(), "because it already exists") {
|
||||
require.False(t, ok, "route should not exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubRange(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var subRangeAddressPrefixes []netip.Prefix
|
||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range subRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if !isSubRangePrefix {
|
||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if isSubRangePrefix {
|
||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistsInRouteTable(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
|
||||
switch {
|
||||
case p.Addr().Is6():
|
||||
continue
|
||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
|
||||
continue
|
||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
|
||||
continue
|
||||
// FreeBSD loopback 127/8 is not added to the routing table
|
||||
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
|
||||
continue
|
||||
default:
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
exists, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("address %s should exist in route table", prefix)
|
||||
}
|
||||
if nh.Intf.Name != defaultNh.Intf.Name {
|
||||
t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
err := r.AddVPNRoute(prefix, intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
if err := r.AddVPNRoute(prefix, intf); err != nil {
|
||||
if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
|
||||
t.Fatalf("addVPNRoute should not return err: %v", err)
|
||||
}
|
||||
t.Logf("addVPNRoute %v returned: %v", prefix, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
err = r.RemoveVPNRoute(prefix, intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
|
||||
t.Fatalf("removeVPNRoute should not return err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) {
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
|
||||
// unique route in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
||||
return
|
||||
}
|
||||
|
||||
prefixNexthop, err := GetNextHop(prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
@@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
@@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
|
||||
ones, _ := route.Dst.Mask.Size()
|
||||
|
||||
prefix := netip.PrefixFrom(addr, ones)
|
||||
prefix := netip.PrefixFrom(addr.Unmap(), ones)
|
||||
if prefix.IsValid() {
|
||||
prefixList = append(prefixList, prefix)
|
||||
}
|
||||
@@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
||||
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
|
||||
return fmt.Errorf("netlink add route: %w", err)
|
||||
}
|
||||
|
||||
@@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
||||
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
|
||||
return fmt.Errorf("netlink add unreachable route: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
var expectedLoopbackInt = "lo"
|
||||
var expectedExternalInt = "dummyext0"
|
||||
var expectedInternalInt = "dummyint0"
|
||||
|
||||
@@ -31,12 +30,6 @@ func init() {
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To more specific route (local) without custom dialer via physical interface",
|
||||
expectedInterface: expectedLoopbackInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,10 +11,16 @@ import (
|
||||
)
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
|
||||
268
client/internal/routemanager/systemops/systemops_test.go
Normal file
268
client/internal/routemanager/systemops/systemops_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
type mockWGIface struct {
|
||||
address wgaddr.Address
|
||||
name string
|
||||
}
|
||||
|
||||
func (m *mockWGIface) Address() wgaddr.Address {
|
||||
return m.address
|
||||
}
|
||||
|
||||
func (m *mockWGIface) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func TestSysOps_validateRoute(t *testing.T) {
|
||||
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
|
||||
mockWG := &mockWGIface{
|
||||
address: wgaddr.Address{
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "wg0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
wgInterface: mockWG,
|
||||
notifier: ¬ifier.Notifier{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
expectError bool
|
||||
}{
|
||||
// Valid routes
|
||||
{
|
||||
name: "valid IPv4 route",
|
||||
prefix: "192.168.1.0/24",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid IPv6 route",
|
||||
prefix: "2001:db8::/32",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid single IPv4 host",
|
||||
prefix: "8.8.8.8/32",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid single IPv6 host",
|
||||
prefix: "2001:4860:4860::8888/128",
|
||||
expectError: false,
|
||||
},
|
||||
|
||||
// Invalid routes - loopback
|
||||
{
|
||||
name: "IPv4 loopback",
|
||||
prefix: "127.0.0.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
prefix: "::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - link-local unicast
|
||||
{
|
||||
name: "IPv4 link-local unicast",
|
||||
prefix: "169.254.1.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 link-local unicast",
|
||||
prefix: "fe80::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - multicast
|
||||
{
|
||||
name: "IPv4 multicast",
|
||||
prefix: "224.0.0.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 multicast",
|
||||
prefix: "ff02::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - link-local multicast
|
||||
{
|
||||
name: "IPv4 link-local multicast",
|
||||
prefix: "224.0.0.0/24",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 link-local multicast",
|
||||
prefix: "ff02::/16",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - interface-local multicast (IPv6 only)
|
||||
{
|
||||
name: "IPv6 interface-local multicast",
|
||||
prefix: "ff01::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - overlaps with WG interface network
|
||||
{
|
||||
name: "overlaps with WG network - exact match",
|
||||
prefix: "10.0.0.0/24",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "overlaps with WG network - subset",
|
||||
prefix: "10.0.0.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "overlaps with WG network - host in range",
|
||||
prefix: "10.0.0.100/32",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prefix, err := netip.ParsePrefix(tt.prefix)
|
||||
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
|
||||
|
||||
err = sysOps.validateRoute(prefix)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
|
||||
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
|
||||
} else {
|
||||
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
|
||||
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
|
||||
mockWG := &mockWGIface{
|
||||
address: wgaddr.Address{
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "wg0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
wgInterface: mockWG,
|
||||
notifier: ¬ifier.Notifier{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "identical subnet",
|
||||
prefix: "192.168.100.0/24",
|
||||
expectError: true,
|
||||
description: "exact same network as WG interface",
|
||||
},
|
||||
{
|
||||
name: "broader subnet containing WG network",
|
||||
prefix: "192.168.0.0/16",
|
||||
expectError: false,
|
||||
description: "broader network that contains WG network should be allowed",
|
||||
},
|
||||
{
|
||||
name: "host within WG network",
|
||||
prefix: "192.168.100.50/32",
|
||||
expectError: true,
|
||||
description: "specific host within WG network",
|
||||
},
|
||||
{
|
||||
name: "subnet within WG network",
|
||||
prefix: "192.168.100.128/25",
|
||||
expectError: true,
|
||||
description: "smaller subnet within WG network",
|
||||
},
|
||||
{
|
||||
name: "adjacent subnet - same /23",
|
||||
prefix: "192.168.101.0/24",
|
||||
expectError: false,
|
||||
description: "adjacent subnet, no overlap",
|
||||
},
|
||||
{
|
||||
name: "adjacent subnet - different /16",
|
||||
prefix: "192.167.100.0/24",
|
||||
expectError: false,
|
||||
description: "different network, no overlap",
|
||||
},
|
||||
{
|
||||
name: "WG network broadcast address",
|
||||
prefix: "192.168.100.255/32",
|
||||
expectError: true,
|
||||
description: "broadcast address of WG network",
|
||||
},
|
||||
{
|
||||
name: "WG network first usable",
|
||||
prefix: "192.168.100.1/32",
|
||||
expectError: true,
|
||||
description: "first usable address in WG network",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prefix, err := netip.ParsePrefix(tt.prefix)
|
||||
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
|
||||
|
||||
err = sysOps.validateRoute(prefix)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
|
||||
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
|
||||
} else {
|
||||
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
|
||||
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
|
||||
mockWG := &mockWGIface{
|
||||
address: wgaddr.Address{
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "wt0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
wgInterface: mockWG,
|
||||
notifier: ¬ifier.Notifier{},
|
||||
}
|
||||
|
||||
var invalidPrefix netip.Prefix
|
||||
err := sysOps.validateRoute(invalidPrefix)
|
||||
|
||||
require.Error(t, err, "validateRoute() expected error for invalid prefix")
|
||||
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
|
||||
}
|
||||
@@ -3,15 +3,19 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("add", prefix, nexthop)
|
||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("delete", prefix, nexthop)
|
||||
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
inet := "-inet"
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
}
|
||||
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
if nexthop.IP.IsValid() {
|
||||
args = append(args, nexthop.IP.Unmap().String())
|
||||
} else if nexthop.Intf != nil {
|
||||
args = append(args, "-interface", nexthop.Intf.Name)
|
||||
}
|
||||
|
||||
if err := retryRouteCmd(args); err != nil {
|
||||
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func retryRouteCmd(args []string) error {
|
||||
operation := func() error {
|
||||
out, err := exec.Command("route", args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
// https://github.com/golang/go/issues/45736
|
||||
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
|
||||
return err
|
||||
} else if err != nil {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return nil
|
||||
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
if !prefix.IsValid() {
|
||||
return fmt.Errorf("invalid prefix: %s", prefix)
|
||||
}
|
||||
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
@@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error {
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||
|
||||
err := backoff.Retry(operation, expBackOff)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route cmd retry failed: %w", err)
|
||||
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
|
||||
a := "add"
|
||||
if action == unix.RTM_DELETE {
|
||||
a = "remove"
|
||||
}
|
||||
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
|
||||
operation := func() error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
|
||||
}
|
||||
|
||||
msgBytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
|
||||
}
|
||||
|
||||
if _, err = unix.Write(fd, msgBytes); err != nil {
|
||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||
}
|
||||
|
||||
respBuf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, respBuf)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return operation
|
||||
}
|
||||
|
||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||
msg = &route.RouteMessage{
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
const numAddrs = unix.RTAX_NETMASK + 1
|
||||
addrs := make([]route.Addr, numAddrs)
|
||||
|
||||
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
|
||||
}
|
||||
|
||||
if prefix.IsSingleIP() {
|
||||
msg.Flags |= unix.RTF_HOST
|
||||
} else {
|
||||
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
if nexthop.IP.IsValid() {
|
||||
msg.Flags |= unix.RTF_GATEWAY
|
||||
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
|
||||
}
|
||||
} else if nexthop.Intf != nil {
|
||||
msg.Index = nexthop.Intf.Index
|
||||
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
||||
Index: nexthop.Intf.Index,
|
||||
Name: nexthop.Intf.Name,
|
||||
}
|
||||
}
|
||||
|
||||
msg.Addrs = addrs
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) parseRouteResponse(buf []byte) error {
|
||||
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||
return nil
|
||||
}
|
||||
|
||||
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
if rtMsg.Errno != 0 {
|
||||
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
||||
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
||||
if addr.Is4() {
|
||||
return &route.Inet4Addr{IP: addr.As4()}, nil
|
||||
}
|
||||
|
||||
if addr.Zone() == "" {
|
||||
return &route.Inet6Addr{IP: addr.As16()}, nil
|
||||
}
|
||||
|
||||
var zone int
|
||||
// zone can be either a numeric zone ID or an interface name.
|
||||
if z, err := strconv.Atoi(addr.Zone()); err == nil {
|
||||
zone = z
|
||||
} else {
|
||||
iface, err := net.InterfaceByName(addr.Zone())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
|
||||
}
|
||||
zone = iface.Index
|
||||
}
|
||||
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
|
||||
}
|
||||
|
||||
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
|
||||
bits := prefix.Bits()
|
||||
if prefix.Addr().Is4() {
|
||||
m := net.CIDRMask(bits, 32)
|
||||
var maskBytes [4]byte
|
||||
copy(maskBytes[:], m)
|
||||
return &route.Inet4Addr{IP: maskBytes}, nil
|
||||
}
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
m := net.CIDRMask(bits, 128)
|
||||
var maskBytes [16]byte
|
||||
copy(maskBytes[:], m)
|
||||
return &route.Inet6Addr{IP: maskBytes}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user