Compare commits

..

1 Commits

Author SHA1 Message Date
Zoltán Papp
7573dfb354 Before lazy connection, when the peer disconnected, the status switched to disconnected.
After implementing lazy connection, the peer state is connecting, so we did not decrease the reference counters on the routes.
2025-06-01 12:46:04 +02:00
157 changed files with 4128 additions and 7701 deletions

View File

@@ -21,6 +21,7 @@ jobs:
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum
only_warn: 1
golangci:
strategy:
fail-fast: false

View File

@@ -65,13 +65,6 @@ jobs:
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu

View File

@@ -172,11 +172,11 @@ jobs:
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0

View File

@@ -175,11 +175,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
ids:
- netbird
goarch: arm
@@ -192,12 +191,11 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
@@ -209,11 +207,9 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
@@ -225,11 +221,9 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
@@ -242,12 +236,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
@@ -259,11 +251,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
@@ -275,11 +266,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
@@ -292,11 +282,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
ids:
- netbird-signal
goarch: amd64
@@ -308,11 +297,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
ids:
- netbird-signal
goarch: arm64
@@ -324,11 +312,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
ids:
- netbird-signal
goarch: arm
@@ -341,11 +328,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
ids:
- netbird-mgmt
goarch: amd64
@@ -357,11 +343,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
ids:
- netbird-mgmt
goarch: arm64
@@ -373,11 +358,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-arm
ids:
- netbird-mgmt
goarch: arm
@@ -390,11 +374,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
ids:
- netbird-mgmt
goarch: amd64
@@ -406,11 +389,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
ids:
- netbird-mgmt
goarch: arm64
@@ -422,12 +404,11 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
ids:
- netbird-mgmt
goarch: arm
@@ -440,11 +421,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
@@ -456,11 +436,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
@@ -472,11 +451,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
@@ -489,7 +467,7 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
@@ -568,84 +546,6 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/relay:latest
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:latest
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:debug-latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/upload:latest
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://docs.netbird.io/slack-url">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
@@ -29,7 +29,7 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
<br/>
</strong>

View File

@@ -71,42 +71,6 @@ func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled store if rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled read rosenpass enabled from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive store the given permissive and wait for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive read rosenpass permissive from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)

View File

@@ -69,22 +69,6 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip]
}
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -39,6 +39,7 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
@@ -77,6 +78,7 @@ var (
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
blockLANAccess bool
debugUploadBundle bool
debugUploadBundleURL string
lazyConnEnabled bool

View File

@@ -2,7 +2,6 @@ package cmd
import (
"context"
"runtime"
"sync"
"github.com/kardianos/service"
@@ -28,19 +27,12 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
}
func newSVCConfig() *service.Config {
config := &service.Config{
return &service.Config{
Name: serviceName,
DisplayName: "Netbird",
Description: "Netbird mesh network client",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
Option: make(service.KeyValue),
EnvVars: make(map[string]string),
}
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
}
func newSVC(prg *program, conf *service.Config) (service.Service, error) {

View File

@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
}
if logFile != "" {
if logFile != "console" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
}

View File

@@ -6,8 +6,6 @@ const (
disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
)
var (
@@ -15,8 +13,6 @@ var (
disableServerRoutes bool
disableDNS bool
disableFirewall bool
blockLANAccess bool
blockInbound bool
)
func init() {
@@ -32,11 +28,4 @@ func init() {
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
}

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,

View File

@@ -55,11 +55,12 @@ func init() {
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
@@ -118,9 +119,83 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}
ic, err := setupConfig(customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup config: %v", err)
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
providedSetupKey, err := getSetupKey()
@@ -128,7 +203,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}
config, err := internal.UpdateOrCreateConfig(*ic)
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
@@ -187,141 +262,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
providedSetupKey, err := getSetupKey()
if err != nil {
return fmt.Errorf("get setup key: %v", err)
return err
}
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
return &ic, nil
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
@@ -358,7 +301,7 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
return err
}
loginRequest.InterfaceName = &interfaceName
}
@@ -393,14 +336,49 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockLanAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
loginRequest.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
}
return &loginRequest, nil
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func validateNATExternalIPs(list []string) error {

View File

@@ -147,10 +147,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -202,7 +198,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
firewall.ProtocolALL,
"all",
nil,
nil,
firewall.ActionAccept,
@@ -223,16 +219,10 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
}

View File

@@ -2,7 +2,7 @@ package iptables
import (
"fmt"
"net/netip"
"net"
"testing"
"time"
@@ -19,8 +19,11 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -67,12 +70,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -92,9 +95,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := netip.MustParseAddr("10.20.0.3")
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
@@ -116,8 +119,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -138,11 +144,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -180,8 +186,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -203,11 +212,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err)
ip := netip.MustParseAddr("10.20.0.100")
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -248,6 +248,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
@@ -274,6 +278,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)

View File

@@ -116,8 +116,6 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,

View File

@@ -170,10 +170,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -328,16 +324,10 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
}

View File

@@ -3,6 +3,7 @@ package nftables
import (
"bytes"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
@@ -24,8 +25,11 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -66,11 +70,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
ip := netip.MustParseAddr("100.96.0.1").Unmap()
ip := net.ParseIP("100.96.0.1")
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -105,6 +109,8 @@ func TestNftablesManager(t *testing.T) {
}
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{
&expr.Payload{
DestRegister: 1,
@@ -126,7 +132,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip.AsSlice(),
Data: add.AsSlice(),
},
&expr.Payload{
DestRegister: 1,
@@ -167,8 +173,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -188,11 +197,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
}()
ip := netip.MustParseAddr("10.20.0.100")
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -273,8 +282,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
})
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(

View File

@@ -573,6 +573,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -1002,6 +1006,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}

View File

@@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
ip net.IP
netstack bool
}
@@ -71,11 +71,12 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: iface.Address().Network.Bits(),
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
},
}
@@ -115,7 +116,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
ip: iface.Address().IP,
}
receiveWindow := defaultReceiveWindow
@@ -166,7 +167,7 @@ func (f *Forwarder) Stop() {
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
@@ -178,6 +179,7 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
}
}

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
}
var rxPackets, txPackets uint64

View File

@@ -45,26 +45,24 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if ipv4 := ip.To4(); ipv4 != nil {
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
ipStr := ipv4.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
}
}
}
@@ -81,12 +79,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -104,13 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue
}
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
@@ -124,8 +116,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}()
var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []netip.Addr
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@@ -20,8 +20,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost range",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
expected: true,
@@ -29,8 +32,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
expected: true,
@@ -38,8 +44,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
expected: true,
@@ -47,8 +56,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
expected: true,
@@ -56,8 +68,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
@@ -65,8 +80,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false,
@@ -74,8 +92,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("fe80::1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
expected: false,

View File

@@ -38,8 +38,11 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}

View File

@@ -39,12 +39,8 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
)
@@ -75,6 +71,7 @@ type Manager struct {
// incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -151,11 +148,6 @@ func parseCreateEnv() (bool, bool) {
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
}
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
}
return disableConntrack, enableLocalForwarding
@@ -277,7 +269,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil:
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
@@ -334,10 +326,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
@@ -618,8 +606,9 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true
}
// for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size)
if m.stateful {
m.trackOutbound(d, srcIP, dstIP, size)
}
return false
}
@@ -788,10 +777,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true
}
// If requested we pass local traffic to internal interfaces to the forwarder.
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
return m.handleForwardedLocalTraffic(packetData)
// if running in netstack mode we need to pass this to the forwarder
if m.netstack && m.localForwarding {
return m.handleNetstackLocalTraffic(packetData)
}
// track inbound packets to get the correct direction and session id for flows
@@ -801,7 +789,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return false
}
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load()
if fwd == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)")
@@ -1099,6 +1088,11 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true
}
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not

View File

@@ -174,6 +174,11 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup
sc.setupFunc(manager)
@@ -214,6 +219,11 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table
srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count)
@@ -257,6 +267,11 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@@ -289,6 +304,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -302,6 +321,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -316,6 +339,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -329,6 +356,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -342,6 +373,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -355,6 +390,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -369,6 +408,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "post_handshake",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -383,6 +426,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -396,6 +443,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -542,6 +593,11 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
@@ -625,6 +681,11 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
@@ -736,6 +797,11 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -816,6 +882,11 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
@@ -961,8 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
}
for _, r := range rules {
dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}

View File

@@ -19,8 +19,12 @@ import (
)
func TestPeerACLFiltering(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100")
wgNet := netip.MustParsePrefix("100.10.0.0/16")
localIP := net.ParseIP("100.10.0.100")
wgNet := &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
@@ -39,6 +43,8 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil))
})
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs()
require.NoError(t, err)
@@ -575,13 +581,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix(network)
localIP, wgNet, err := net.ParseCIDR(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
IP: localIP,
Network: wgNet,
}
},
@@ -1433,8 +1440,11 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}

View File

@@ -271,8 +271,11 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
@@ -282,6 +285,10 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err)
return
}
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
@@ -389,6 +396,10 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() {
@@ -498,6 +509,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{

View File

@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a) {
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}

View File

@@ -1,17 +0,0 @@
package configurer
import (
"net"
"net/netip"
)
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -5,7 +5,6 @@ package configurer
import (
"fmt"
"net"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
@@ -13,8 +12,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var zeroKey wgtypes.Key
type KernelConfigurer struct {
deviceName string
}
@@ -46,7 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -55,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps),
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,
@@ -92,10 +89,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
return nil
}
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipNet := net.IPNet{
IP: allowedIP.Addr().AsSlice(),
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -106,7 +103,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet},
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
@@ -119,10 +116,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
return nil
}
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipNet := net.IPNet{
IP: allowedIP.Addr().AsSlice(),
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return fmt.Errorf("parse allowed IP: %w", err)
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -190,11 +187,7 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
if err != nil {
return err
}
defer func() {
if err := wg.Close(); err != nil {
log.Errorf("Failed to close wgctrl client: %v", err)
}
}()
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
@@ -208,47 +201,6 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() {
}
func (c *KernelConfigurer) FullStats() (*Stats, error) {
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()

View File

@@ -5,7 +5,6 @@ import (
"encoding/hex"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"strconv"
@@ -20,17 +19,10 @@ import (
)
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -68,7 +60,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -77,7 +69,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps),
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
@@ -107,10 +99,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipNet := net.IPNet{
IP: allowedIP.Addr().AsSlice(),
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -121,7 +113,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet},
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
@@ -131,7 +123,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet()
if err != nil {
return err
@@ -154,8 +146,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
foundPeer := false
removedAllowedIP := false
ip := allowedIP.String()
for _, line := range lines {
line = strings.TrimSpace(line)
@@ -178,8 +168,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
// Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIPStr)
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
@@ -196,15 +186,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error
@@ -384,136 +365,3 @@ func getFwmark() int {
}
return 0
}
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -1,24 +0,0 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -1,6 +1,7 @@
package device
import (
"net"
"net/netip"
"sync"
@@ -23,6 +24,9 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
}
// FilteredDevice to override Read or Write of packets

View File

@@ -51,11 +51,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@@ -2,7 +2,6 @@ package device
import (
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -12,11 +11,10 @@ import (
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
}

View File

@@ -64,15 +64,7 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
}
ip := address.IP.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
mask := "0x" + address.Network.Mask.String()
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -111,14 +111,14 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
}
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional.
// If allowedIps is given it will be added to the existing ones.
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
netIPNets := prefixesToIPNets(allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@@ -131,7 +131,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
@@ -140,7 +140,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
@@ -185,6 +185,7 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
}
w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter)
return nil
@@ -216,10 +217,6 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)
@@ -254,3 +251,14 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet()
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -5,6 +5,7 @@
package mocks
import (
net "net"
"net/netip"
reflect "reflect"
@@ -89,3 +90,15 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,6 +1,8 @@
package netstack
import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
@@ -13,8 +15,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive
address netip.Addr
dnsAddress netip.Addr
address net.IP
dnsAddress net.IP
mtu int
listenAddress string
@@ -22,7 +24,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device
}
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
return &NetStackTun{
address: address,
dnsAddress: dnsAddress,
@@ -32,9 +34,19 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A
}
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{t.address},
[]netip.Addr{t.dnsAddress},
[]netip.Addr{addr.Unmap()},
[]netip.Addr{dnsAddr.Unmap()},
t.mtu)
if err != nil {
return nil, nil, err

View File

@@ -2,27 +2,28 @@ package wgaddr
import (
"fmt"
"net/netip"
"net"
)
// Address WireGuard parsed address
type Address struct {
IP netip.Addr
Network netip.Prefix
IP net.IP
Network *net.IPNet
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) {
prefix, err := netip.ParsePrefix(address)
ip, network, err := net.ParseCIDR(address)
if err != nil {
return Address{}, err
}
return Address{
IP: prefix.Addr().Unmap(),
Network: prefix.Masked(),
IP: ip,
Network: network,
}, nil
}
func (addr Address) String() string {
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@@ -58,11 +58,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.mutex.Lock()
defer d.mutex.Unlock()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
start := time.Now()
defer func() {
total := 0
@@ -74,8 +69,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total)
}()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
@@ -284,10 +285,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
// TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")

View File

@@ -1,14 +1,13 @@
package acl
import (
"net/netip"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -43,31 +42,35 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
IP: ip,
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
if err != nil {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
if fw.IsStateful() {
assert.Equal(t, 0, len(acl.peerRulesPairs))
} else {
assert.Equal(t, 2, len(acl.peerRulesPairs))
if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
return
}
})
@@ -91,13 +94,12 @@ func TestDefaultManager(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
expectedRules := 2
if fw.IsStateful() {
expectedRules = 1 // only the inbound rule
// we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied")
return
}
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
// check that old rule was removed
previousCount := 0
for id := range acl.peerRulesPairs {
@@ -105,86 +107,26 @@ func TestDefaultManager(t *testing.T) {
previousCount++
}
}
expectedPreviousCount := 0
if !fw.IsStateful() {
expectedPreviousCount = 1
if previousCount != 1 {
t.Errorf("old rule was not removed")
}
assert.Equal(t, expectedPreviousCount, previousCount)
})
t.Run("handle default rules", func(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, 0, len(acl.peerRulesPairs))
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap, false)
expectedRules := 1
if fw.IsStateful() {
expectedRules = 1 // only inbound allow-all rule
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
}
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
})
}
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
// In stateless mode, we should have both inbound and outbound rules
assert.False(t, fw.IsStateful())
assert.Equal(t, 2, len(acl.peerRulesPairs))
})
}
@@ -250,19 +192,42 @@ func TestDefaultManagerSquashRules(t *testing.T) {
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules))
if len(rules) != 2 {
t.Errorf("rules should contain 2, got: %v", rules)
return
}
r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
switch {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
switch {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
}
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
@@ -326,8 +291,9 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) {
@@ -370,29 +336,33 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
IP: ip,
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
if err != nil {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false)
expectedRules := 3
if fw.IsStateful() {
expectedRules = 3 // 2 inbound rules + SSH rule
if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return
}
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
}

View File

@@ -68,8 +68,8 @@ type ConfigInput struct {
DisableServerRoutes *bool
DisableDNS *bool
DisableFirewall *bool
BlockLANAccess *bool
BlockInbound *bool
BlockLANAccess *bool
DisableNotifications *bool
@@ -98,8 +98,8 @@ type Config struct {
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
BlockLANAccess bool
DisableNotifications *bool
@@ -483,16 +483,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
if *input.BlockInbound {
log.Infof("blocking inbound connections")
} else {
log.Infof("allowing inbound connections")
}
config.BlockInbound = *input.BlockInbound
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications {
log.Infof("disabling notifications")

View File

@@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
@@ -34,9 +33,9 @@ type ConnMgr struct {
lazyConnMgr *manager.Manager
wg sync.WaitGroup
lazyCtx context.Context
lazyCtxCancel context.CancelFunc
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
@@ -86,7 +85,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager()
return e.addPeersToLazyConnManager(ctx)
} else {
if e.lazyConnMgr == nil {
return nil
@@ -98,18 +97,8 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
}
}
// UpdateRouteHAMap updates the route HA mappings in the lazy connection manager
func (e *ConnMgr) UpdateRouteHAMap(haMap route.HAMap) {
if !e.isStartedWithLazyMgr() {
log.Debugf("lazy connection manager is not started, skipping UpdateRouteHAMap")
return
}
e.lazyConnMgr.UpdateRouteHAMap(haMap)
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
if e.lazyConnMgr == nil {
return
}
@@ -133,7 +122,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
@@ -143,7 +132,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
}
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(ctx); err != nil {
if err := peerConn.Open(e.ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err)
}
}
@@ -221,9 +210,9 @@ func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn,
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(ctx); err != nil {
if err := conn.Open(e.ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
@@ -235,27 +224,29 @@ func (e *ConnMgr) Close() {
return
}
e.lazyCtxCancel()
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
}
func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
ctx, cancel := context.WithCancel(parentCtx)
e.ctx = ctx
e.ctxCancel = cancel
e.wg.Add(1)
go func() {
defer e.wg.Done()
e.lazyConnMgr.Start(e.lazyCtx)
e.lazyConnMgr.Start(ctx)
}()
}
func (e *ConnMgr) addPeersToLazyConnManager() error {
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers {
@@ -275,7 +266,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {
@@ -283,7 +274,7 @@ func (e *ConnMgr) closeManager(ctx context.Context) {
return
}
e.lazyCtxCancel()
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
@@ -293,7 +284,7 @@ func (e *ConnMgr) closeManager(ctx context.Context) {
}
func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
return e.lazyConnMgr != nil && e.ctxCancel != nil
}
func inactivityThresholdEnv() *time.Duration {

View File

@@ -436,12 +436,11 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
DisableServerRoutes: config.DisableServerRoutes,
DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
BlockInbound: config.BlockInbound,
BlockLANAccess: config.BlockLANAccess,
LazyConnectionEnabled: config.LazyConnectionEnabled,
}
@@ -500,9 +499,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -270,21 +270,11 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
}
if err := g.addWgShow(); err != nil {
log.Errorf("Failed to add wg show output: %v", err)
}
if g.logFile != "console" && g.logFile != "" {
if g.logFile != "console" {
if err := g.addLogfile(); err != nil {
log.Errorf("Failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs as fallback: %v", err)
}
return fmt.Errorf("add log file: %w", err)
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs: %v", err)
}
return nil
}
@@ -376,33 +366,17 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
if g.internalConfig.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
if g.internalConfig.DisableNotifications != nil {
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
}
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
if g.internalConfig.ClientCertPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath))
}
if g.internalConfig.ClientCertKeyPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes))
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
}

View File

@@ -4,104 +4,17 @@ package debug
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"os"
"os/exec"
"sort"
"strings"
"time"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
)
const (
maxLogEntries = 100000
maxLogAge = 7 * 24 * time.Hour // Last 7 days
)
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
func (g *BundleGenerator) trySystemdLogFallback() error {
log.Debug("Attempting to collect systemd journal logs")
serviceName := getServiceName()
journalLogs, err := getSystemdLogs(serviceName)
if err != nil {
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
}
if strings.Contains(journalLogs, "No recent log entries found") {
log.Debug("No recent log entries found in systemd journal")
return nil
}
if g.anonymize {
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
}
logReader := strings.NewReader(journalLogs)
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
if err := g.addFileToZip(logReader, fileName); err != nil {
return fmt.Errorf("add systemd logs to bundle: %w", err)
}
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
return nil
}
// getServiceName gets the service name from environment or defaults to netbird
func getServiceName() string {
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
return unitName
}
return "netbird"
}
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
func getSystemdLogs(serviceName string) (string, error) {
args := []string{
"-u", fmt.Sprintf("%s.service", serviceName),
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
"--lines", fmt.Sprintf("%d", maxLogEntries),
"--no-pager",
"--output", "short-iso",
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "journalctl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
}
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("journalctl command not found: %w", err)
}
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
}
logs := stdout.String()
if strings.TrimSpace(logs) == "" {
return "No recent log entries found in systemd journal", nil
}
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
serviceName, maxLogEntries, maxLogAge.String())
return header + logs, nil
}
// addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
@@ -568,7 +481,7 @@ func formatExpr(exp expr.Any) string {
case *expr.Fib:
return formatFib(e)
case *expr.Target:
return fmt.Sprintf("jump %s", e.Name)
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets
case *expr.Immediate:
if e.Register == 1 {
return formatImmediateData(e.Data)

View File

@@ -6,9 +6,3 @@ package debug
func (g *BundleGenerator) addFirewallRules() error {
return nil
}
func (g *BundleGenerator) trySystemdLogFallback() error {
// Systemd is only available on Linux
// TODO: Add BSD support
return nil
}

View File

@@ -1,66 +0,0 @@
package debug
import (
"bytes"
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGIface interface {
FullStats() (*configurer.Stats, error)
}
func (g *BundleGenerator) addWgShow() error {
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err
}
output := g.toWGShowFormat(result)
reader := bytes.NewReader([]byte(output))
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
return fmt.Errorf("add wg show to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
if s.FWMark != 0 {
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
}
for _, peer := range s.Peers {
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
if peer.Endpoint.IP != nil {
if g.anonymize {
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
} else {
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
}
}
if len(peer.AllowedIPs) > 0 {
var ipStrings []string
for _, ipnet := range peer.AllowedIPs {
ipStrings = append(ipStrings, ipnet.String())
}
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
sb.WriteString(" preshared key: (hidden)\n")
}
}
return sb.String()
}

View File

@@ -2,7 +2,7 @@ package internal
import (
"fmt"
"net/netip"
"net"
"slices"
"strings"
@@ -12,14 +12,13 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip, err := netip.ParseAddr(aRecord.RData)
if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData)
if ip == nil || ip.To4() == nil {
return nbdns.SimpleRecord{}, false
}
if !prefix.Contains(ip) {
if !ipNet.Contains(ip) {
return nbdns.SimpleRecord{}, false
}
@@ -37,19 +36,16 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.Sim
}
// generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := network.Masked().Addr()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask)
maskOnes, _ := ipNet.Mask.Size()
// round up to nearest byte
octetsToUse := (network.Bits() + 7) / 8
octetsToUse := (maskOnes + 7) / 8
octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
}
reverseOctets := make([]string, octetsToUse)
@@ -72,7 +68,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
}
// collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
@@ -81,7 +77,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
continue
}
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
records = append(records, ptrRecord)
}
}
@@ -91,8 +87,8 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
}
// addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
zoneName, err := generateReverseZoneName(network)
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
zoneName, err := generateReverseZoneName(ipNet)
if err != nil {
log.Warn(err)
return
@@ -103,7 +99,7 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
return
}
records := collectPTRRecords(config, network)
records := collectPTRRecords(config, ipNet)
reverseZone := nbdns.CustomZone{
Domain: zoneName,

View File

@@ -1,7 +1,6 @@
package dns
import (
"fmt"
"slices"
"strings"
"sync"
@@ -149,42 +148,61 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
log.Tracef("current handlers (%d):", len(handlers))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order
for _, entry := range handlers {
matched := c.isHandlerMatch(qname, entry)
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
var matched bool
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = strings.EqualFold(qname, entry.Pattern)
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
continue
}
return
}
if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue
}
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler")
continue
}
return
}
// No handler matched or all handlers passed
@@ -195,22 +213,3 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Errorf("failed to write DNS response: %v", err)
}
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
return strings.EqualFold(qname, entry.Pattern)
}
}
}

View File

@@ -1,14 +1,11 @@
package dns
import (
"context"
"errors"
"fmt"
"io"
"os/exec"
"strings"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -44,20 +41,6 @@ const (
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
registrationEnabledKey = "RegistrationEnabled"
maxNumberOfAddressesToRegisterKey = "MaxNumberOfAddressesToRegister"
// NetBIOS/WINS settings
netbtInterfacePath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces`
netbiosOptionsKey = "NetbiosOptions"
// NetBIOS option values: 0 = from DHCP, 1 = enabled, 2 = disabled
netbiosFromDHCP = 0
netbiosEnabled = 1
netbiosDisabled = 2
// RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1
)
@@ -84,85 +67,16 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
log.Infof("detected GPO DNS policy configuration, using policy store")
}
configurator := &registryConfigurator{
return &registryConfigurator{
guid: guid,
gpo: useGPO,
}
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
return configurator, nil
}, nil
}
func (r *registryConfigurator) supportCustomPort() bool {
return false
}
func (r *registryConfigurator) configureInterface() error {
var merr *multierror.Error
if err := r.disableDNSRegistrationForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable DNS registration: %w", err))
}
if err := r.disableWINSForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable WINS: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableDNSRegistrationForInterface() error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closer(regKey)
var merr *multierror.Error
if err := regKey.SetDWordValue(disableDynamicUpdateKey, 1); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", disableDynamicUpdateKey, err))
}
if err := regKey.SetDWordValue(registrationEnabledKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", registrationEnabledKey, err))
}
if err := regKey.SetDWordValue(maxNumberOfAddressesToRegisterKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", maxNumberOfAddressesToRegisterKey, err))
}
if merr == nil || len(merr.Errors) == 0 {
log.Infof("disabled DNS registration for interface %s", r.guid)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableWINSForInterface() error {
netbtKeyPath := fmt.Sprintf(`%s\Tcpip_%s`, netbtInterfacePath, r.guid)
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
regKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create NetBT interface key %s: %w", netbtKeyPath, err)
}
}
defer closer(regKey)
// NetbiosOptions: 2 = disabled
if err := regKey.SetDWordValue(netbiosOptionsKey, netbiosDisabled); err != nil {
return fmt.Errorf("set %s: %w", netbiosOptionsKey, err)
}
log.Infof("disabled WINS/NetBIOS for interface %s", r.guid)
return nil
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if config.RouteAll {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
@@ -205,7 +119,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err)
}
go r.flushDNSCache()
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}
@@ -275,25 +191,7 @@ func (r *registryConfigurator) string() string {
return "registry"
}
func (r *registryConfigurator) registerDNS() {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
// nolint:misspell
cmd := exec.CommandContext(ctx, "ipconfig", "/registerdns")
out, err := cmd.CombinedOutput()
if err != nil {
log.Errorf("failed to register DNS: %v, output: %s", err, out)
return
}
log.Info("registered DNS names")
}
func (r *registryConfigurator) flushDNSCache() {
r.registerDNS()
func (r *registryConfigurator) flushDNSCache() error {
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() {
if rec := recover(); rec != nil {
@@ -304,14 +202,13 @@ func (r *registryConfigurator) flushDNSCache() {
ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
log.Errorf("DnsFlushResolverCache failed: %v", err)
return
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
}
log.Errorf("DnsFlushResolverCache failed")
return
return fmt.Errorf("DnsFlushResolverCache failed")
}
log.Info("flushed DNS cache")
return nil
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
@@ -366,7 +263,9 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err)
}
go r.flushDNSCache()
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}

View File

@@ -12,19 +12,16 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
)
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
}
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
}
}
@@ -67,12 +64,8 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
// TODO: return success if we have a different record type for the same name, relevant for search domains
replyMessage.Rcode = dns.RcodeNameError
}
if err := w.WriteMsg(replyMessage); err != nil {
@@ -80,15 +73,6 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
return exists
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
@@ -127,7 +111,6 @@ func (d *Resolver) Update(update []nbdns.SimpleRecord) {
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
@@ -161,7 +144,6 @@ func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
}
d.records[q] = append(d.records[q], rr)
d.domains[domain.Domain(q.Name)] = struct{}{}
return nil
}

View File

@@ -470,115 +470,3 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
})
}
}
// TestLocalResolver_NoErrorWithDifferentRecordType verifies that querying for a record type
// that doesn't exist but where other record types exist for the same domain returns NOERROR
// with 0 records instead of NXDOMAIN
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
resolver := NewResolver()
recordA := nbdns.SimpleRecord{
Name: "example.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
}
recordCNAME := nbdns.SimpleRecord{
Name: "alias.netbird.cloud.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
testCases := []struct {
name string
queryName string
queryType uint16
expectedRcode int
shouldHaveData bool
}{
{
name: "Query A record that exists",
queryName: "example.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query AAAA for domain with only A record",
queryName: "example.netbird.cloud.",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query other record with different case and non-fqdn",
queryName: "EXAMPLE.netbird.cloud",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query TXT for domain with only A record",
queryName: "example.netbird.cloud.",
queryType: dns.TypeTXT,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query A for domain with only CNAME record",
queryName: "alias.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query AAAA for domain with only CNAME record",
queryName: "alias.netbird.cloud.",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query for completely non-existent domain",
queryName: "nonexistent.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeNameError,
shouldHaveData: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response message")
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode,
"Response code should be %d (%s)",
tc.expectedRcode, dns.RcodeToString[tc.expectedRcode])
if tc.shouldHaveData {
assert.Greater(t, len(responseMSG.Answer), 0, "Response should contain answers")
} else {
assert.Equal(t, 0, len(responseMSG.Answer), "Response should contain no answers")
}
})
}
}

View File

@@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
}
}
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
log.Debugf("extra match domains: %v", s.extraDomains)
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)

View File

@@ -46,9 +46,10 @@ func (w *mocWGIface) Name() string {
}
func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{
IP: netip.MustParseAddr("100.66.100.1"),
Network: netip.MustParsePrefix("100.66.100.0/24"),
IP: ip,
Network: network,
}
}
@@ -463,10 +464,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)

View File

@@ -24,15 +24,11 @@ type ServiceViaMemory struct {
}
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP.String(),
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
runtimePort: defaultPort,
}
return s
@@ -95,7 +91,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().IP.Is6() {
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}

View File

@@ -0,0 +1,33 @@
package dns
import (
"net"
"testing"
nbnet "github.com/netbirdio/netbird/util/net"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -3,7 +3,6 @@ package dns
import (
"context"
"net"
"net/netip"
"syscall"
"time"
@@ -24,8 +23,8 @@ type upstreamResolver struct {
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,

View File

@@ -4,7 +4,7 @@ package dns
import (
"context"
"net/netip"
"net"
"time"
"github.com/miekg/dns"
@@ -19,8 +19,8 @@ type upstreamResolver struct {
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,

View File

@@ -6,7 +6,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"syscall"
"time"
@@ -19,16 +18,16 @@ import (
type upstreamResolverIOS struct {
*upstreamResolverBase
lIP netip.Addr
lNet netip.Prefix
lIP net.IP
lNet *net.IPNet
interfaceName string
}
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip netip.Addr,
net netip.Prefix,
ip net.IP,
net *net.IPNet,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -59,11 +58,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
client.DialTimeout = timeout
upstreamIP, err := netip.ParseAddr(upstreamHost)
if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
@@ -77,7 +73,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
@@ -86,7 +82,7 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: ip.AsSlice(),
IP: ip,
Port: 0, // Let the OS pick a free port
},
Timeout: dialTimeout,

View File

@@ -2,7 +2,7 @@ package dns
import (
"context"
"net/netip"
"net"
"strings"
"testing"
"time"
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {

View File

@@ -121,8 +121,8 @@ type EngineConfig struct {
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
BlockLANAccess bool
LazyConnectionEnabled bool
}
@@ -359,7 +359,6 @@ func (e *Engine) Start() error {
return fmt.Errorf("new wg interface: %w", err)
}
e.wgInterface = wgIface
e.statusRecorder.SetWgIface(wgIface)
// start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey()
@@ -381,6 +380,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("run rosenpass manager: %w", err)
}
}
e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer()
@@ -431,8 +431,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("up wg interface: %w", err)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
if e.firewall != nil {
e.acl = acl.NewDefaultManager(e.firewall)
}
@@ -488,9 +487,11 @@ func (e *Engine) createFirewall() error {
}
func (e *Engine) initFirewall() error {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
if e.firewall.IsServerRouteSupported() {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
}
}
if e.config.BlockLANAccess {
@@ -524,11 +525,6 @@ func (e *Engine) initFirewall() error {
}
func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
return
}
var merr *multierror.Error
// TODO: keep this updated
@@ -786,9 +782,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -803,58 +796,56 @@ func isNil(server nbssh.Server) bool {
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Infof("SSH server is disabled because inbound connections are blocked")
return nil
}
if !e.config.ServerSSHAllowed {
log.Info("SSH server is not enabled")
log.Warnf("running SSH server is not permitted")
return nil
}
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
return nil
}
return nil
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
@@ -908,9 +899,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
)
// err = e.mgmClient.Sync(info, e.handleSync)
@@ -1000,29 +988,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
// lazy mgr needs to be aware of which routes are available before they are applied
if e.connMgr != nil {
e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
if e.acl != nil {
@@ -1081,8 +1052,17 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(excludedLazyPeers)
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial
@@ -1118,7 +1098,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
Network: prefix.Masked(),
Network: prefix,
Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType),
@@ -1152,7 +1132,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
return entries
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
@@ -1467,7 +1447,6 @@ func (e *Engine) close() {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
}
e.wgInterface = nil
e.statusRecorder.SetWgIface(nil)
}
if !isNil(e.sshServer) {
@@ -1499,9 +1478,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
)
netMap, err := e.mgmClient.GetNetworkMap(info)
@@ -1695,7 +1671,7 @@ func (e *Engine) RunHealthProbes() bool {
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)...,
)
}
@@ -1808,9 +1784,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
}
// GetWgAddr returns the wireguard address
func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) GetWgAddr() net.IP {
if e.wgInterface == nil {
return netip.Addr{}
return nil
}
return e.wgInterface.Address().IP
}
@@ -1820,10 +1796,6 @@ func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
) {
if e.config.DisableServerRoutes {
return
}
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1879,7 +1851,12 @@ func (e *Engine) Address() (netip.Addr, error) {
return netip.Addr{}, errors.New("wireguard interface not initialized")
}
return e.wgInterface.Address().IP, nil
addr := e.wgInterface.Address()
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
}
return ip.Unmap(), nil
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
@@ -1950,8 +1927,18 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
return forwardingRules, nberrors.FormatErrorOrNil(merr)
}
func (e *Engine) toExcludedLazyPeers(rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
excludedPeers := make(map[string]bool)
for _, r := range routes {
if r.Peer == "" {
continue
}
if !excludedPeers[r.Peer] {
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers[r.Peer] = true
}
}
for _, r := range rules {
ip := r.TranslatedAddress
for _, p := range peers {

View File

@@ -86,8 +86,8 @@ type MockWGIface struct {
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
@@ -99,10 +99,6 @@ type MockWGIface struct {
GetNetFunc func() *netstack.Net
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented")
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
@@ -147,11 +143,11 @@ func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
@@ -375,8 +371,11 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
@@ -643,12 +642,12 @@ func TestEngine_Sync(t *testing.T) {
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
testCases := []struct {
name string
inputErr error
networkMap *mgmtProto.NetworkMap
expectedLen int
expectedClientRoutes route.HAMap
expectedSerial uint64
name string
inputErr error
networkMap *mgmtProto.NetworkMap
expectedLen int
expectedRoutes []*route.Route
expectedSerial uint64
}{
{
name: "Routes Config Should Be Passed To Manager",
@@ -676,26 +675,22 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
},
},
expectedLen: 2,
expectedClientRoutes: route.HAMap{
"n1|192.168.0.0/24": []*route.Route{
{
ID: "a",
Network: netip.MustParsePrefix("192.168.0.0/24"),
NetID: "n1",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
expectedRoutes: []*route.Route{
{
ID: "a",
Network: netip.MustParsePrefix("192.168.0.0/24"),
NetID: "n1",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
"n2|192.168.1.0/24": []*route.Route{
{
ID: "b",
Network: netip.MustParsePrefix("192.168.1.0/24"),
NetID: "n2",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
{
ID: "b",
Network: netip.MustParsePrefix("192.168.1.0/24"),
NetID: "n2",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
},
expectedSerial: 1,
@@ -708,9 +703,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedLen: 0,
expectedClientRoutes: nil,
expectedSerial: 1,
expectedLen: 0,
expectedRoutes: []*route.Route{},
expectedSerial: 1,
},
{
name: "Error Shouldn't Break Engine",
@@ -721,9 +716,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedLen: 0,
expectedClientRoutes: nil,
expectedSerial: 1,
expectedLen: 0,
expectedRoutes: []*route.Route{},
expectedSerial: 1,
},
}
@@ -766,29 +761,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.wgInterface, err = iface.NewWGIFace(opts)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
clientRoutes route.HAMap
inputSerial uint64
inputRoutes []*route.Route
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
input.inputSerial = updateSerial
input.clientRoutes = clientRoutes
input.inputRoutes = newRoutes
return testCase.inputErr
},
ClassifyRoutesFunc: func(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
if len(newRoutes) == 0 {
return nil, nil
}
// Classify all routes as client routes (not matching our public key)
clientRoutes := make(route.HAMap)
for _, r := range newRoutes {
haID := r.GetHAUniqueID()
clientRoutes[haID] = append(clientRoutes[haID], r)
}
return nil, clientRoutes
},
}
engine.routeManager = mockRouteManager
@@ -806,8 +788,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedClientRoutes, input.clientRoutes, "clientRoutes should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
})
}
}
@@ -968,7 +950,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}

View File

@@ -28,8 +28,8 @@ type wgIfaceBase interface {
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter
@@ -37,5 +37,4 @@ type wgIfaceBase interface {
GetWGDevice() *wgdevice.Device
GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
}

View File

@@ -6,7 +6,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
@@ -14,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
)
const (
@@ -39,9 +37,7 @@ type Config struct {
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
// - Managing route HA groups and activating all peers in a group when one peer is activated
type Manager struct {
engineCtx context.Context
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
@@ -55,20 +51,13 @@ type Manager struct {
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
// Route HA group management
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex // protects route mappings
cancel context.CancelFunc
onInactive chan peerid.ConnID
}
// NewManager creates a new lazy connection manager
// engineCtx is the context for creating peer Connection
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
engineCtx: engineCtx,
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
@@ -77,8 +66,6 @@ func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.S
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
peerToHAGroups: make(map[string][]route.HAUniqueID),
haGroupToPeers: make(map[route.HAUniqueID][]string),
onInactive: make(chan peerid.ConnID),
}
@@ -100,45 +87,11 @@ func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.S
return m
}
// UpdateRouteHAMap updates the HA group mappings for routes
// This should be called when route configuration changes
func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
m.routesMu.Lock()
defer m.routesMu.Unlock()
maps.Clear(m.peerToHAGroups)
maps.Clear(m.haGroupToPeers)
for haUniqueID, routes := range haMap {
var peers []string
peerSet := make(map[string]bool)
for _, r := range routes {
if !peerSet[r.Peer] {
peerSet[r.Peer] = true
peers = append(peers, r.Peer)
}
}
if len(peers) <= 1 {
continue
}
m.haGroupToPeers[haUniqueID] = peers
for _, peerID := range peers {
m.peerToHAGroups[peerID] = append(m.peerToHAGroups[peerID], haUniqueID)
}
}
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
len(m.haGroupToPeers), len(m.peerToHAGroups))
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
ctx, m.cancel = context.WithCancel(ctx)
for {
select {
case <-ctx.Done():
@@ -256,47 +209,25 @@ func (m *Manager) RemovePeer(peerID string) {
}
// ActivatePeer activates a peer connection when a signal message is received
// Also activates all peers in the same HA groups as this peer
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
return false
}
if !m.activateSinglePeer(ctx, cfg, mp) {
return false
}
m.activateHAGroupPeers(ctx, peerID)
return true
}
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
// Returns nil values if the peer should be skipped
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
cfg, ok := m.managedPeers[peerID]
if !ok {
return nil, nil
return false
}
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
return nil, nil
return false
}
// signal messages coming continuously after success activation, with this avoid the multiple activation
if mp.expectedWatcher == watcherInactivity {
return nil, nil
return false
}
return cfg, mp
}
// activateSinglePeer activates a single peer (internal method)
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
@@ -307,53 +238,12 @@ func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConf
return false
}
cfg.Log.Infof("starting inactivity monitor")
mp.peerCfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
return true
}
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID]
m.routesMu.RUnlock()
if len(haGroups) == 0 {
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return
}
activatedCount := 0
for _, haGroup := range haGroups {
m.routesMu.RLock()
peers := m.haGroupToPeers[haGroup]
m.routesMu.RUnlock()
for _, peerID := range peers {
if peerID == triggerPeerID {
continue
}
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
}
}
if activatedCount > 0 {
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
activatedCount, triggerPeerID, haGroups)
}
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
@@ -397,6 +287,8 @@ func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.cancel()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
@@ -405,13 +297,6 @@ func (m *Manager) close() {
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
// Clear route mappings
m.routesMu.Lock()
m.peerToHAGroups = make(map[string][]route.HAUniqueID)
m.haGroupToPeers = make(map[route.HAUniqueID][]string)
m.routesMu.Unlock()
log.Infof("lazy connection manager closed")
}
@@ -432,13 +317,12 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
mp.peerCfg.Log.Infof("detected peer activity")
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
return
}
mp.expectedWatcher = watcherInactivity
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("starting inactivity monitor")
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {

View File

@@ -116,9 +116,6 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
)
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, err
@@ -142,9 +139,6 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
// fallback if mark rules are not in place
wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
}
// mapRxPackets maps packet counts to RX based on flow direction
@@ -293,15 +293,17 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
// fallback if marks are not set
wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch {
case wgaddr == srcIP:
case wgaddr.Equal(src):
return nftypes.Egress
case wgaddr == dstIP:
case wgaddr.Equal(dst):
return nftypes.Ingress
case wgnetwork.Contains(srcIP):
case wgnetwork.Contains(src):
// netbird network -> resource network
return nftypes.Ingress
case wgnetwork.Contains(dstIP):
case wgnetwork.Contains(dst):
// resource network -> netbird network
return nftypes.Egress
}

View File

@@ -2,7 +2,7 @@ package logger
import (
"context"
"net/netip"
"net"
"sync"
"sync/atomic"
"time"
@@ -23,16 +23,17 @@ type Logger struct {
rcvChan atomic.Pointer[rcvChan]
cancel context.CancelFunc
statusRecorder *peer.Status
wgIfaceNet netip.Prefix
wgIfaceIPNet net.IPNet
dnsCollection atomic.Bool
exitNodeCollection atomic.Bool
Store types.Store
}
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
return &Logger{
statusRecorder: statusRecorder,
wgIfaceNet: wgIfaceIPNet,
wgIfaceIPNet: wgIfaceIPNet,
Store: store.NewMemoryStore(),
}
}
@@ -88,11 +89,11 @@ func (l *Logger) startReceiver() {
var isSrcExitNode bool
var isDestExitNode bool
if !l.wgIfaceNet.Contains(event.SourceIP) {
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
}
if !l.wgIfaceNet.Contains(event.DestIP) {
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
}

View File

@@ -1,7 +1,7 @@
package logger_test
import (
"net/netip"
"net"
"testing"
"time"
@@ -12,7 +12,7 @@ import (
)
func TestStore(t *testing.T) {
logger := logger.New(nil, netip.Prefix{})
logger := logger.New(nil, net.IPNet{})
logger.Enable()
event := types.EventFields{

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/netip"
"net"
"runtime"
"sync"
"time"
@@ -34,11 +34,11 @@ type Manager struct {
// NewManager creates a new netflow manager
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var prefix netip.Prefix
var ipNet net.IPNet
if iface != nil {
prefix = iface.Address().Network
ipNet = *iface.Address().Network
}
flowLogger := logger.New(statusRecorder, prefix)
flowLogger := logger.New(statusRecorder, ipNet)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {

View File

@@ -1,7 +1,7 @@
package netflow
import (
"net/netip"
"net"
"testing"
"time"
@@ -33,7 +33,10 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
func TestManager_Update(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: netip.MustParsePrefix("192.168.1.1/32"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
},
isUserspaceBind: true,
}
@@ -99,7 +102,10 @@ func TestManager_Update(t *testing.T) {
func TestManager_Update_TokenPreservation(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: netip.MustParsePrefix("192.168.1.1/32"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
},
isUserspaceBind: true,
}

View File

@@ -1,9 +1,7 @@
package peer
import (
"context"
"errors"
"fmt"
"net/netip"
"slices"
"sync"
@@ -33,21 +31,10 @@ type ResolvedDomainInfo struct {
ParentDomain domain.Domain
}
type WGIfaceStatus interface {
FullStats() (*configurer.Stats, error)
}
type EventListener interface {
OnEvent(event *proto.SystemEvent)
}
// RouterState status for router peers. This contains relevant fields for route manager
type RouterState struct {
Status ConnStatus
Relayed bool
Latency time.Duration
}
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
@@ -159,32 +146,11 @@ type FullStatus struct {
LazyConnectionEnabled bool
}
type StatusChangeSubscription struct {
peerID string
id string
eventsChan chan map[string]RouterState
ctx context.Context
}
func newStatusChangeSubscription(ctx context.Context, peerID string) *StatusChangeSubscription {
return &StatusChangeSubscription{
ctx: ctx,
peerID: peerID,
id: uuid.New().String(),
// it is a buffer for notifications to block less the status recorded
eventsChan: make(chan map[string]RouterState, 8),
}
}
func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
return s.eventsChan
}
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.Mutex
peers map[string]State
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
@@ -215,14 +181,13 @@ type Status struct {
ingressGwMgr *ingressgw.Manager
routeIDLookup routeIDLookup
wgIface WGIfaceStatus
}
// NewRecorder returns a new Status instance
func NewRecorder(mgmAddress string) *Status {
return &Status{
peers: make(map[string]State),
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
changeNotify: make(map[string]chan struct{}),
eventStreams: make(map[string]chan *proto.SystemEvent),
eventQueue: NewEventQueue(eventQueueSize),
offlinePeers: make([]State, 0),
@@ -324,7 +289,11 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return errors.New("peer doesn't exist")
}
oldState := peerState.ConnStatus
if receivedState.IP != "" {
peerState.IP = receivedState.IP
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
peerState.ConnStatus = receivedState.ConnStatus
@@ -340,14 +309,11 @@ func (d *Status) UpdatePeerState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
if skipNotification {
return nil
}
// when we close the connection we will not notify the router manager
if receivedState.ConnStatus == StatusIdle {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
d.notifyPeerListChanged()
return nil
}
@@ -414,8 +380,11 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return errors.New("peer doesn't exist")
}
oldState := peerState.ConnStatus
oldIsRelayed := peerState.Relayed
if receivedState.IP != "" {
peerState.IP = receivedState.IP
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
@@ -428,13 +397,12 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
if skipNotification {
return nil
}
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -447,8 +415,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return errors.New("peer doesn't exist")
}
oldState := peerState.ConnStatus
oldIsRelayed := peerState.Relayed
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
@@ -458,13 +425,12 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
if skipNotification {
return nil
}
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -477,8 +443,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return errors.New("peer doesn't exist")
}
oldState := peerState.ConnStatus
oldIsRelayed := peerState.Relayed
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed
@@ -487,13 +452,12 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
if skipNotification {
return nil
}
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -506,8 +470,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return errors.New("peer doesn't exist")
}
oldState := peerState.ConnStatus
oldIsRelayed := peerState.Relayed
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
peerState.ConnStatus = receivedState.ConnStatus
peerState.Relayed = receivedState.Relayed
@@ -519,13 +482,12 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
if skipNotification {
return nil
}
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -548,12 +510,17 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGSt
return nil
}
func hasStatusOrRelayedChange(oldConnStatus, newConnStatus ConnStatus, oldRelayed, newRelayed bool) bool {
return oldRelayed != newRelayed || hasConnStatusChanged(newConnStatus, oldConnStatus)
}
func hasConnStatusChanged(oldStatus, newStatus ConnStatus) bool {
return newStatus != oldStatus
func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
switch {
case receivedConnStatus == StatusConnecting:
return true
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting:
return true
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
return curr.IP != ""
default:
return false
}
}
// UpdatePeerFQDN update peer's state fqdn only
@@ -584,47 +551,21 @@ func (d *Status) FinishPeerListModifications() {
d.mux.Unlock()
d.notifyPeerListChanged()
for key := range d.peers {
d.notifyPeerStateChangeListeners(key)
}
}
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()
sub := newStatusChangeSubscription(ctx, peerID)
if _, ok := d.changeNotify[peerID]; !ok {
d.changeNotify[peerID] = make(map[string]*StatusChangeSubscription)
}
d.changeNotify[peerID][sub.id] = sub
return sub
}
func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscription) {
d.mux.Lock()
defer d.mux.Unlock()
if subscription == nil {
return
ch, found := d.changeNotify[peer]
if found {
return ch
}
channels, ok := d.changeNotify[subscription.peerID]
if !ok {
return
}
sub, exists := channels[subscription.id]
if !exists {
return
}
delete(channels, subscription.id)
if len(channels) == 0 {
delete(d.changeNotify, sub.peerID)
}
ch = make(chan struct{})
d.changeNotify[peer] = ch
return ch
}
// GetLocalPeerState returns the local peer state
@@ -999,33 +940,13 @@ func (d *Status) onConnectionChanged() {
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
subs, ok := d.changeNotify[peerID]
if !ok {
ch, found := d.changeNotify[peerID]
if !found {
return
}
// collect the relevant data for router peers
routerPeers := make(map[string]RouterState, len(d.changeNotify))
for pid := range d.changeNotify {
s, ok := d.peers[pid]
if !ok {
log.Warnf("router peer not found in peers list: %s", pid)
continue
}
routerPeers[pid] = RouterState{
Status: s.ConnStatus,
Relayed: s.Relayed,
Latency: s.Latency,
}
}
for _, sub := range subs {
select {
case sub.eventsChan <- routerPeers:
case <-sub.ctx.Done():
}
}
close(ch)
delete(d.changeNotify, peerID)
}
func (d *Status) notifyPeerListChanged() {
@@ -1109,23 +1030,6 @@ func (d *Status) GetEventHistory() []*proto.SystemEvent {
return d.eventQueue.GetAll()
}
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
d.mux.Lock()
defer d.mux.Unlock()
d.wgIface = wgInterface
}
func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.Lock()
defer d.mux.Unlock()
if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
}
return d.wgIface.FullStats()
}
type EventQueue struct {
maxSize int
events []*proto.SystemEvent

View File

@@ -1,11 +1,9 @@
package peer
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -44,16 +42,16 @@ func TestGetPeer(t *testing.T) {
func TestUpdatePeerState(t *testing.T) {
key := "abc"
ip := "10.10.10.10"
fqdn := "peer-a.netbird.local"
status := NewRecorder("https://mgm")
_ = status.AddPeer(key, fqdn, ip)
peerState := State{
PubKey: key,
ConnStatusUpdate: time.Now(),
ConnStatus: StatusConnecting,
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
peerState.IP = ip
err := status.UpdatePeerState(peerState)
assert.NoError(t, err, "shouldn't return error")
@@ -85,27 +83,25 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
key := "abc"
ip := "10.10.10.10"
status := NewRecorder("https://mgm")
_ = status.AddPeer(key, "abc.netbird", ip)
sub := status.SubscribeToPeerStateChanges(context.Background(), key)
assert.NotNil(t, sub, "channel shouldn't be nil")
peerState := State{
PubKey: key,
ConnStatus: StatusConnecting,
Relayed: false,
ConnStatusUpdate: time.Now(),
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
ch := status.GetPeerStateChangeNotifier(key)
assert.NotNil(t, ch, "channel shouldn't be nil")
peerState.IP = ip
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error")
timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
select {
case <-sub.eventsChan:
case <-timeoutCtx.Done():
t.Errorf("timed out waiting for event")
case <-ch:
default:
t.Errorf("channel wasn't closed after update")
}
}

View File

@@ -170,7 +170,7 @@ func ProbeAll(
var wg sync.WaitGroup
for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
wg.Add(1)

View File

@@ -0,0 +1,544 @@
package routemanager
import (
"context"
"fmt"
"reflect"
"runtime"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
)
const (
handlerTypeDynamic = iota
handlerTypeDomain
handlerTypeStatic
)
type reason int
const (
reasonUnknown reason = iota
reasonRouteUpdate
reasonPeerUpdate
reasonShutdown
)
type routerPeerStatus struct {
connected bool
relayed bool
latency time.Duration
}
type routesUpdate struct {
updateSerial uint64
routes []*route.Route
}
// RouteHandler defines the interface for handling routes
type RouteHandler interface {
String() string
AddRoute(ctx context.Context) error
RemoveRoute() error
AddAllowedIPs(peerKey string) error
RemoveAllowedIPs() error
}
type clientNetwork struct {
ctx context.Context
cancel context.CancelFunc
statusRecorder *peer.Status
wgInterface iface.WGIface
routes map[route.ID]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{}
currentChosen *route.Route
handler RouteHandler
updateSerial uint64
}
func newClientNetworkWatcher(
ctx context.Context,
dnsRouteInterval time.Duration,
wgInterface iface.WGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
ctx: ctx,
cancel: cancel,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouteInterval,
statusRecorder,
wgInterface,
dnsServer,
peerStore,
useNewDNSRoute,
),
}
return client
}
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[route.ID]routerPeerStatus)
for _, r := range c.routes {
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
if err != nil {
log.Debugf("couldn't fetch peer state: %v", err)
continue
}
routePeerStatuses[r.ID] = routerPeerStatus{
connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed,
latency: peerStatus.Latency,
}
}
return routePeerStatuses
}
// getBestRouteFromStatuses determines the most optimal route from the available routes
// within a clientNetwork, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections.
//
// It follows these prioritization rules:
// * Connected peers: Only routes with connected peers are considered.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Latency: Routes with lower latency are prioritized.
// * we compare the current score + 10ms to the chosen score to avoid flapping between routes
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
chosen := route.ID("")
chosenScore := float64(0)
currScore := float64(0)
currID := route.ID("")
if c.currentChosen != nil {
currID = c.currentChosen.ID
}
for _, r := range c.routes {
tempScore := float64(0)
peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected {
continue
}
if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
}
// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}
// higher latency is worse score
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
tempScore++
}
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
chosen = r.ID
chosenScore = tempScore
}
if chosen == "" && currID == "" {
chosen = r.ID
chosenScore = tempScore
}
if r.ID == currID {
currScore = tempScore
}
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
switch {
case chosen == "":
var peers []string
for _, r := range c.routes {
peers = append(peers, r.Peer)
}
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
return currID
}
var p string
if rt := c.routes[chosen]; rt != nil {
p = rt.Peer
}
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
}
return chosen
}
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
for {
select {
case <-ctx.Done():
return
case <-closer:
return
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil {
continue
}
peerStateUpdate <- struct{}{}
log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus)
}
}
}
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if found {
continue
}
closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
}
}
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
}
return nil
}
func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error {
if c.currentChosen == nil {
return nil
}
var merr *multierror.Error
if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
}
c.disconnectEvent(rsn)
return nberrors.FormatErrorOrNil(merr)
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error {
routerPeerStatuses := c.getRouterPeerStatuses()
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system
if newChosenID == "" {
if err := c.removeRouteFromPeerAndSystem(rsn); err != nil {
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
}
c.currentChosen = nil
return nil
}
// If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
c.currentChosen.Equal(c.routes[newChosenID]) {
return nil
}
var isNew bool
if c.currentChosen == nil {
// If they were not previously assigned to another peer, add routes to the system first
if err := c.handler.AddRoute(c.ctx); err != nil {
return fmt.Errorf("add route: %w", err)
}
isNew = true
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
c.currentChosen = c.routes[newChosenID]
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
if isNew {
c.connectEvent()
}
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
return nil
}
func (c *clientNetwork) connectEvent() {
var defaultRoute bool
for _, r := range c.routes {
if r.Network.Bits() == 0 {
defaultRoute = true
break
}
}
if !defaultRoute {
return
}
meta := map[string]string{
"network": c.handler.String(),
}
if c.currentChosen != nil {
meta["id"] = string(c.currentChosen.NetID)
meta["peer"] = c.currentChosen.Peer
}
c.statusRecorder.PublishEvent(
proto.SystemEvent_INFO,
proto.SystemEvent_NETWORK,
"Default route added",
"Exit node connected.",
meta,
)
}
func (c *clientNetwork) disconnectEvent(rsn reason) {
var defaultRoute bool
for _, r := range c.routes {
if r.Network.Bits() == 0 {
defaultRoute = true
break
}
}
if !defaultRoute {
return
}
var severity proto.SystemEvent_Severity
var message string
var userMessage string
meta := make(map[string]string)
if c.currentChosen != nil {
meta["id"] = string(c.currentChosen.NetID)
meta["peer"] = c.currentChosen.Peer
}
meta["network"] = c.handler.String()
switch rsn {
case reasonShutdown:
severity = proto.SystemEvent_INFO
message = "Default route removed"
userMessage = "Exit node disconnected."
case reasonRouteUpdate:
severity = proto.SystemEvent_INFO
message = "Default route updated due to configuration change"
case reasonPeerUpdate:
severity = proto.SystemEvent_WARNING
message = "Default route disconnected due to peer unreachability"
userMessage = "Exit node connection lost. Your internet access might be affected."
default:
severity = proto.SystemEvent_ERROR
message = "Default route disconnected for unknown reasons"
userMessage = "Exit node disconnected for unknown reasons."
}
c.statusRecorder.PublishEvent(
severity,
proto.SystemEvent_NETWORK,
message,
userMessage,
meta,
)
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update
}()
}
func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
isUpdateMapDifferent := false
updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes {
updateMap[r.ID] = r
}
if len(c.routes) != len(updateMap) {
isUpdateMapDifferent = true
}
for id, r := range c.routes {
_, found := updateMap[id]
if !found {
close(c.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true
continue
}
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
isUpdateMapDifferent = true
}
}
c.routes = updateMap
return isUpdateMapDifferent
}
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
// All the processing related to the client network should be done here. Thread-safe.
func (c *clientNetwork) peersStateAndUpdateWatcher() {
for {
select {
case <-c.ctx.Done():
log.Debugf("Stopping watcher for network [%v]", c.handler)
if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
}
return
case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate)
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial {
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
continue
}
log.Debugf("Received a new client network route update for [%v]", c.handler)
// hash update somehow
isTrueRouteUpdate := c.handleUpdate(update)
c.updateSerial = update.updateSerial
if isTrueRouteUpdate {
log.Debug("Client network update contains different routes, recalculating routes")
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate)
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
} else {
log.Debug("Route update is not different, skipping route recalculation")
}
c.startPeersStatusChangeWatcher()
}
}
}
func handlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler {
switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDomain:
return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
peerStore,
)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
default:
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
}
func handlerType(rt *route.Route, useNewDNSRoute bool) int {
if !rt.IsDynamic() {
return handlerTypeStatic
}
if useNewDNSRoute && runtime.GOOS != "ios" {
return handlerTypeDomain
}
return handlerTypeDynamic
}

View File

@@ -1,603 +0,0 @@
package client
import (
"context"
"fmt"
"reflect"
"runtime"
"time"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
)
const (
handlerTypeDynamic = iota
handlerTypeDomain
handlerTypeStatic
)
type reason int
const (
reasonUnknown reason = iota
reasonRouteUpdate
reasonPeerUpdate
reasonShutdown
reasonHA
)
type routerPeerStatus struct {
status peer.ConnStatus
relayed bool
latency time.Duration
}
type RoutesUpdate struct {
UpdateSerial uint64
Routes []*route.Route
}
// RouteHandler defines the interface for handling routes
type RouteHandler interface {
String() string
AddRoute(ctx context.Context) error
RemoveRoute() error
AddAllowedIPs(peerKey string) error
RemoveAllowedIPs() error
}
type WatcherConfig struct {
Context context.Context
DNSRouteInterval time.Duration
WGInterface iface.WGIface
StatusRecorder *peer.Status
Route *route.Route
Handler RouteHandler
}
// Watcher watches route and peer changes and updates allowed IPs accordingly.
// Once stopped, it cannot be reused.
// The methods are not thread-safe and should be synchronized externally.
type Watcher struct {
ctx context.Context
cancel context.CancelFunc
statusRecorder *peer.Status
wgInterface iface.WGIface
routes map[route.ID]*route.Route
routeUpdate chan RoutesUpdate
peerStateUpdate chan map[string]peer.RouterState
routePeersNotifiers map[string]chan struct{} // map of peer key to channel for peer state changes
currentChosen *route.Route
currentChosenStatus *routerPeerStatus
handler RouteHandler
updateSerial uint64
}
func NewWatcher(config WatcherConfig) *Watcher {
ctx, cancel := context.WithCancel(config.Context)
client := &Watcher{
ctx: ctx,
cancel: cancel,
statusRecorder: config.StatusRecorder,
wgInterface: config.WGInterface,
routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan RoutesUpdate),
peerStateUpdate: make(chan map[string]peer.RouterState),
handler: config.Handler,
currentChosenStatus: nil,
}
return client
}
func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[route.ID]routerPeerStatus)
for _, r := range w.routes {
peerStatus, err := w.statusRecorder.GetPeer(r.Peer)
if err != nil {
log.Debugf("couldn't fetch peer state %v: %v", r.Peer, err)
continue
}
routePeerStatuses[r.ID] = routerPeerStatus{
status: peerStatus.ConnStatus,
relayed: peerStatus.Relayed,
latency: peerStatus.Latency,
}
}
return routePeerStatuses
}
func (w *Watcher) convertRouterPeerStatuses(states map[string]peer.RouterState) map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[route.ID]routerPeerStatus)
for _, r := range w.routes {
peerStatus, ok := states[r.Peer]
if !ok {
log.Warnf("couldn't fetch peer state: %v", r.Peer)
continue
}
routePeerStatuses[r.ID] = routerPeerStatus{
status: peerStatus.Status,
relayed: peerStatus.Relayed,
latency: peerStatus.Latency,
}
}
return routePeerStatuses
}
// getBestRouteFromStatuses determines the most optimal route from the available routes
// within a Watcher, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections.
//
// It follows these prioritization rules:
// * Connection status: Both connected and idle peers are considered, but connected peers always take precedence.
// * Idle peer penalty: Idle peers receive a significant score penalty to ensure any connected peer is preferred.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Latency: Routes with lower latency are prioritized.
// * Allowed IPs: Idle peers can still receive allowed IPs to enable lazy connection triggering.
// * we compare the current score + 10ms to the chosen score to avoid flapping between routes
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) (route.ID, routerPeerStatus) {
var chosen route.ID
chosenScore := float64(0)
currScore := float64(0)
var currID route.ID
if w.currentChosen != nil {
currID = w.currentChosen.ID
}
var chosenStatus routerPeerStatus
for _, r := range w.routes {
tempScore := float64(0)
peerStatus, found := routePeerStatuses[r.ID]
// connecting status equals disconnected: no wireguard endpoint to assign allowed IPs to
if !found || peerStatus.status == peer.StatusConnecting {
continue
}
if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else if !peerStatus.relayed && peerStatus.status != peer.StatusIdle {
log.Tracef("peer %s has 0 latency: [%v]", r.Peer, w.handler)
}
// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}
// higher latency is worse score
tempScore += 1 - latency.Seconds()
// apply significant penalty for idle peers to ensure connected peers always take precedence
if peerStatus.status == peer.StatusConnected {
tempScore += 100_000
}
if !peerStatus.relayed {
tempScore++
}
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
chosen = r.ID
chosenStatus = peerStatus
chosenScore = tempScore
}
if chosen == "" && currID == "" {
chosen = r.ID
chosenStatus = peerStatus
chosenScore = tempScore
}
if r.ID == currID {
currScore = tempScore
}
}
chosenID := chosen
if chosen == "" {
chosenID = "<none>"
}
currentID := currID
if currID == "" {
currentID = "<none>"
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosenID, chosenScore, currentID, currScore)
switch {
case chosen == "":
var peers []string
for _, r := range w.routes {
peers = append(peers, r.Peer)
}
log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently available", w.handler, peers)
case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("keeping current routing peer %s for [%v]: the score difference with latency is less than 0.01(10ms): current: %f, new: %f",
w.currentChosen.Peer, w.handler, currScore, chosenScore)
return currID, chosenStatus
}
var p string
if rt := w.routes[chosen]; rt != nil {
p = rt.Peer
}
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, w.handler)
}
return chosen, chosenStatus
}
func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan map[string]peer.RouterState, closer chan struct{}) {
subscription := w.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey)
defer w.statusRecorder.UnsubscribePeerStateChanges(subscription)
for {
select {
case <-ctx.Done():
return
case <-closer:
return
case routerStates := <-subscription.Events():
peerStateUpdate <- routerStates
log.Debugf("triggered route state update for Peer: %s", peerKey)
}
}
}
func (w *Watcher) startNewPeerStatusWatchers() {
for _, r := range w.routes {
if _, found := w.routePeersNotifiers[r.Peer]; found {
continue
}
closerChan := make(chan struct{})
w.routePeersNotifiers[r.Peer] = closerChan
go w.watchPeerStatusChanges(w.ctx, r.Peer, w.peerStateUpdate, closerChan)
}
}
// addAllowedIPs adds the allowed IPs for the current chosen route to the handler.
func (w *Watcher) addAllowedIPs(route *route.Route) error {
if err := w.handler.AddAllowedIPs(route.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", route.Peer, err)
}
if err := w.statusRecorder.AddPeerStateRoute(route.Peer, w.handler.String(), route.GetResourceID()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
w.connectEvent(route)
return nil
}
func (w *Watcher) removeAllowedIPs(route *route.Route, rsn reason) error {
if err := w.statusRecorder.RemovePeerStateRoute(route.Peer, w.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := w.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
}
w.disconnectEvent(route, rsn)
return nil
}
// shouldSkipRecalculation checks if we can skip route recalculation for the same route without status changes
func (w *Watcher) shouldSkipRecalculation(newChosenID route.ID, newStatus routerPeerStatus) bool {
if w.currentChosen == nil {
return false
}
isSameRoute := w.currentChosen.ID == newChosenID && w.currentChosen.Equal(w.routes[newChosenID])
if !isSameRoute {
return false
}
if w.currentChosenStatus != nil {
return w.currentChosenStatus.status == newStatus.status
}
return true
}
func (w *Watcher) recalculateRoutes(rsn reason, routerPeerStatuses map[route.ID]routerPeerStatus) error {
newChosenID, newStatus := w.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer
if newChosenID == "" {
if w.currentChosen == nil {
return nil
}
if err := w.removeAllowedIPs(w.currentChosen, rsn); err != nil {
return fmt.Errorf("remove obsolete: %w", err)
}
w.currentChosen = nil
w.currentChosenStatus = nil
return nil
}
// If we can skip recalculation for the same route without changes, do nothing
if w.shouldSkipRecalculation(newChosenID, newStatus) {
return nil
}
// If the chosen route was assigned to a different peer, remove the allowed IPs first
if isNew := w.currentChosen == nil; !isNew {
if err := w.removeAllowedIPs(w.currentChosen, reasonHA); err != nil {
return fmt.Errorf("remove old: %w", err)
}
}
newChosenRoute := w.routes[newChosenID]
if err := w.addAllowedIPs(newChosenRoute); err != nil {
return fmt.Errorf("add new: %w", err)
}
if newStatus.status != peer.StatusIdle {
w.connectEvent(newChosenRoute)
}
w.currentChosen = newChosenRoute
w.currentChosenStatus = &newStatus
return nil
}
func (w *Watcher) connectEvent(route *route.Route) {
var defaultRoute bool
for _, r := range w.routes {
if r.Network.Bits() == 0 {
defaultRoute = true
break
}
}
if !defaultRoute {
return
}
meta := map[string]string{
"network": w.handler.String(),
}
if route != nil {
meta["id"] = string(route.NetID)
meta["peer"] = route.Peer
}
w.statusRecorder.PublishEvent(
proto.SystemEvent_INFO,
proto.SystemEvent_NETWORK,
"Default route added",
"Exit node connected.",
meta,
)
}
func (w *Watcher) disconnectEvent(route *route.Route, rsn reason) {
var defaultRoute bool
for _, r := range w.routes {
if r.Network.Bits() == 0 {
defaultRoute = true
break
}
}
if !defaultRoute {
return
}
var severity proto.SystemEvent_Severity
var message string
var userMessage string
meta := make(map[string]string)
if route != nil {
meta["id"] = string(route.NetID)
meta["peer"] = route.Peer
}
meta["network"] = w.handler.String()
switch rsn {
case reasonShutdown:
severity = proto.SystemEvent_INFO
message = "Default route removed"
userMessage = "Exit node disconnected."
case reasonRouteUpdate:
severity = proto.SystemEvent_INFO
message = "Default route updated due to configuration change"
case reasonPeerUpdate:
severity = proto.SystemEvent_WARNING
message = "Default route disconnected due to peer unreachability"
userMessage = "Exit node connection lost. Your internet access might be affected."
case reasonHA:
severity = proto.SystemEvent_INFO
message = "Default route disconnected due to high availability change"
userMessage = "Exit node disconnected due to high availability change."
default:
severity = proto.SystemEvent_ERROR
message = "Default route disconnected for unknown reasons"
userMessage = "Exit node disconnected for unknown reasons."
}
w.statusRecorder.PublishEvent(
severity,
proto.SystemEvent_NETWORK,
message,
userMessage,
meta,
)
}
func (w *Watcher) SendUpdate(update RoutesUpdate) {
go func() {
select {
case w.routeUpdate <- update:
case <-w.ctx.Done():
}
}()
}
func (w *Watcher) classifyUpdate(update RoutesUpdate) bool {
isUpdateMapDifferent := false
updateMap := make(map[route.ID]*route.Route)
for _, r := range update.Routes {
updateMap[r.ID] = r
}
if len(w.routes) != len(updateMap) {
isUpdateMapDifferent = true
}
for id, r := range w.routes {
_, found := updateMap[id]
if !found {
close(w.routePeersNotifiers[r.Peer])
delete(w.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true
continue
}
if !reflect.DeepEqual(w.routes[id], updateMap[id]) {
isUpdateMapDifferent = true
}
}
w.routes = updateMap
return isUpdateMapDifferent
}
// Start is the main point of reacting on client network routing events.
// All the processing related to the client network should be done here. Thread-safe.
func (w *Watcher) Start() {
for {
select {
case <-w.ctx.Done():
return
case routersStates := <-w.peerStateUpdate:
routerPeerStatuses := w.convertRouterPeerStatuses(routersStates)
if err := w.recalculateRoutes(reasonPeerUpdate, routerPeerStatuses); err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", w.handler, err)
}
case update := <-w.routeUpdate:
if update.UpdateSerial < w.updateSerial {
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", w.updateSerial, update.UpdateSerial)
continue
}
w.handleRouteUpdate(update)
}
}
}
func (w *Watcher) handleRouteUpdate(update RoutesUpdate) {
log.Debugf("Received a new client network route update for [%v]", w.handler)
// hash update somehow
isTrueRouteUpdate := w.classifyUpdate(update)
w.updateSerial = update.UpdateSerial
if isTrueRouteUpdate {
log.Debugf("client network update %v for [%v] contains different routes, recalculating routes", update.UpdateSerial, w.handler)
routePeerStatuses := w.getRouterPeerStatuses()
if err := w.recalculateRoutes(reasonRouteUpdate, routePeerStatuses); err != nil {
log.Errorf("failed to recalculate routes for network [%v]: %v", w.handler, err)
}
} else {
log.Debugf("route update %v for [%v] is not different, skipping route recalculation", update.UpdateSerial, w.handler)
}
w.startNewPeerStatusWatchers()
}
// Stop stops the watcher and cleans up resources.
func (w *Watcher) Stop() {
log.Debugf("Stopping watcher for network [%v]", w.handler)
w.cancel()
if w.currentChosen == nil {
return
}
if err := w.removeAllowedIPs(w.currentChosen, reasonShutdown); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", w.handler, err)
}
w.currentChosenStatus = nil
}
func HandlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler {
switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDomain:
return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
peerStore,
)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
default:
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
}
func handlerType(rt *route.Route, useNewDNSRoute bool) int {
if !rt.IsDynamic() {
return handlerTypeStatic
}
if useNewDNSRoute && runtime.GOOS != "ios" {
return handlerTypeDomain
}
return handlerTypeDynamic
}

View File

@@ -1,156 +0,0 @@
package client
import (
"context"
"fmt"
"net/netip"
"sync"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
)
type benchmarkTier struct {
name string
peers int
routes int
haPeersPerGroup int
}
var benchmarkTiers = []benchmarkTier{
{"Small", 100, 50, 4},
{"Medium", 1000, 200, 16},
{"Large", 5000, 500, 32},
}
type mockRouteHandler struct {
network string
}
func (m *mockRouteHandler) String() string { return m.network }
func (m *mockRouteHandler) AddRoute(context.Context) error { return nil }
func (m *mockRouteHandler) RemoveRoute() error { return nil }
func (m *mockRouteHandler) AddAllowedIPs(string) error { return nil }
func (m *mockRouteHandler) RemoveAllowedIPs() error { return nil }
func generateBenchmarkData(tier benchmarkTier) (*peer.Status, map[route.ID]*route.Route) {
statusRecorder := peer.NewRecorder("test-mgm")
routes := make(map[route.ID]*route.Route)
peerKeys := make([]string, tier.peers)
for i := 0; i < tier.peers; i++ {
peerKey := fmt.Sprintf("peer-%d", i)
peerKeys[i] = peerKey
fqdn := fmt.Sprintf("peer-%d.example.com", i)
ip := fmt.Sprintf("10.0.%d.%d", i/256, i%256)
err := statusRecorder.AddPeer(peerKey, fqdn, ip)
if err != nil {
panic(fmt.Sprintf("failed to add peer: %v", err))
}
var status peer.ConnStatus
var latency time.Duration
relayed := false
switch i % 10 {
case 0, 1: // 20% disconnected
status = peer.StatusConnecting
latency = 0
case 2: // 10% idle
status = peer.StatusIdle
latency = 50 * time.Millisecond
case 3, 4: // 20% relayed
status = peer.StatusConnected
relayed = true
latency = time.Duration(50+i%100) * time.Millisecond
default: // 50% direct connection
status = peer.StatusConnected
latency = time.Duration(10+i%40) * time.Millisecond
}
// Update peer state
state := peer.State{
PubKey: peerKey,
IP: ip,
FQDN: fqdn,
ConnStatus: status,
ConnStatusUpdate: time.Now(),
Relayed: relayed,
Latency: latency,
Mux: &sync.RWMutex{},
}
err = statusRecorder.UpdatePeerState(state)
if err != nil {
panic(fmt.Sprintf("failed to update peer state: %v", err))
}
}
routeID := 0
for i := 0; i < tier.routes; i++ {
network := fmt.Sprintf("192.168.%d.0/24", i%256)
prefix := netip.MustParsePrefix(network)
haGroupSize := 1
if i%4 == 0 { // 25% of routes have HA
haGroupSize = tier.haPeersPerGroup
}
for j := 0; j < haGroupSize; j++ {
peerIndex := (i*tier.haPeersPerGroup + j) % tier.peers
peerKey := peerKeys[peerIndex]
rID := route.ID(fmt.Sprintf("route-%d-%d", i, j))
metric := 100 + j*10
routes[rID] = &route.Route{
ID: rID,
Network: prefix,
Peer: peerKey,
Metric: metric,
NetID: route.NetID(fmt.Sprintf("net-%d", i)),
}
routeID++
}
}
return statusRecorder, routes
}
// Benchmark the optimized recalculate routes
func BenchmarkRecalculateRoutes(b *testing.B) {
for _, tier := range benchmarkTiers {
b.Run(tier.name, func(b *testing.B) {
statusRecorder, routes := generateBenchmarkData(tier)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
watcher := &Watcher{
ctx: ctx,
statusRecorder: statusRecorder,
routes: routes,
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan RoutesUpdate),
peerStateUpdate: make(chan map[string]peer.RouterState),
handler: &mockRouteHandler{network: "benchmark"},
currentChosenStatus: nil,
}
b.ResetTimer()
b.ReportAllocs()
routePeerStatuses := watcher.getRouterPeerStatuses()
for i := 0; i < b.N; i++ {
err := watcher.recalculateRoutes(reasonPeerUpdate, routePeerStatuses)
if err != nil {
b.Fatalf("recalculateRoutes failed: %v", err)
}
}
})
}
}

View File

@@ -1,827 +0,0 @@
package client
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct {
name string
statuses map[route.ID]routerPeerStatus
expectedRouteID route.ID
currentRoute route.ID
existingRoutes map[route.ID]*route.Route
}{
{
name: "one route",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: false,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "one connected routes with relayed and direct",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: true,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "one connected routes with relayed and no direct",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: true,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "no connected peers",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnecting,
relayed: false,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "",
},
{
name: "multiple connected peers with different metrics",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: false,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: 9000,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "multiple connected peers with one relayed",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: false,
},
"route2": {
status: peer.StatusConnected,
relayed: true,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "multiple connected peers with different latencies",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
latency: 300 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "should ignore routes with latency 0",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
latency: 0 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: false,
latency: 15 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: false,
latency: 200 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route2",
},
{
name: "current chosen route doesn't exist anymore",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnected,
relayed: false,
latency: 20 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "routeDoesntExistAnymore",
expectedRouteID: "route2",
},
{
name: "connected peer should be preferred over idle peer",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
latency: 100 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "idle peer should be selected when no connected peers",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
"route2": {
status: peer.StatusConnecting,
relayed: false,
latency: 5 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "best idle peer should be selected among multiple idle peers",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 100 * time.Millisecond,
},
"route2": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "connecting peers should not be considered for routing",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnecting,
relayed: false,
latency: 10 * time.Millisecond,
},
"route2": {
status: peer.StatusConnecting,
relayed: false,
latency: 5 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "",
},
{
name: "mixed statuses - connected wins over idle and connecting",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusConnecting,
relayed: false,
latency: 5 * time.Millisecond,
},
"route2": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
"route3": {
status: peer.StatusConnected,
relayed: true,
latency: 200 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
"route3": {
ID: "route3",
Metric: route.MaxMetric,
Peer: "peer3",
},
},
currentRoute: "",
expectedRouteID: "route3",
},
{
name: "idle peer with better metric should win over idle peer with worse metric",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 50 * time.Millisecond,
},
"route2": {
status: peer.StatusIdle,
relayed: false,
latency: 50 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: 5000,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "current idle route should be maintained for similar scores",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 20 * time.Millisecond,
},
"route2": {
status: peer.StatusIdle,
relayed: false,
latency: 15 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "idle peer with zero latency should still be considered",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
status: peer.StatusConnecting,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "direct idle peer preferred over relayed idle peer",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: true,
latency: 10 * time.Millisecond,
},
"route2": {
status: peer.StatusIdle,
relayed: false,
latency: 50 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "connected peer with worse metric still beats idle peer with better metric",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
latency: 50 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: 1000,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "connected peer wins even when idle peer has all advantages",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 1 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: true,
latency: 30 * time.Minute,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: 1,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "connected peer should be preferred over idle peer",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
"route2": {
status: peer.StatusConnected,
relayed: false,
latency: 100 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "idle peer should be selected when no connected peers",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
"route2": {
status: peer.StatusConnecting,
relayed: false,
latency: 5 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "best idle peer should be selected among multiple idle peers",
statuses: map[route.ID]routerPeerStatus{
"route1": {
status: peer.StatusIdle,
relayed: false,
latency: 100 * time.Millisecond,
},
"route2": {
status: peer.StatusIdle,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
}
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
status: peer.StatusConnecting,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
status: peer.StatusConnecting,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
ID: "routeDoesntExistAnymore",
}
if tc.currentRoute != "" {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
// create new clientNetwork
client := &Watcher{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes,
currentChosen: currentRoute,
}
chosenRoute, _ := client.getBestRouteFromStatuses(tc.statuses)
if chosenRoute != tc.expectedRouteID {
t.Errorf("expected routeID %s, got %s", tc.expectedRouteID, chosenRoute)
}
})
}
}

View File

@@ -0,0 +1,410 @@
package routemanager
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct {
name string
statuses map[route.ID]routerPeerStatus
expectedRouteID route.ID
currentRoute route.ID
existingRoutes map[route.ID]*route.Route
}{
{
name: "one route",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "one connected routes with relayed and direct",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "one connected routes with relayed and no direct",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "no connected peers",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: false,
relayed: false,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
expectedRouteID: "",
},
{
name: "multiple connected peers with different metrics",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
},
"route2": {
connected: true,
relayed: false,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: 9000,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "multiple connected peers with one relayed",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
},
"route2": {
connected: true,
relayed: true,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "multiple connected peers with different latencies",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
latency: 300 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "should ignore routes with latency 0",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 15 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 200 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route2",
},
{
name: "current chosen route doesn't exist anymore",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "routeDoesntExistAnymore",
expectedRouteID: "route2",
},
}
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
ID: "routeDoesntExistAnymore",
}
if tc.currentRoute != "" {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
// create new clientNetwork
client := &clientNetwork{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes,
currentChosen: currentRoute,
}
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
if chosenRoute != tc.expectedRouteID {
t.Errorf("expected routeID %s, got %s", tc.expectedRouteID, chosenRoute)
}
})
}
}

View File

@@ -6,7 +6,6 @@ import (
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
@@ -64,118 +63,6 @@ func (d *DnsInterceptor) AddRoute(context.Context) error {
return nil
}
// preResolveDomains performs background DNS resolution for non-wildcard domains
func (d *DnsInterceptor) preResolveDomains() {
for _, domain := range d.route.Domains {
domainStr := string(domain)
if strings.HasPrefix(domainStr, "*.") {
continue
}
domainStr = strings.TrimSuffix(domainStr, ".")
go func(domain string) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := d.resolveAndUpdateDomain(ctx, domain); err != nil {
log.Debugf("pre-resolve failed for domain %s: %v", domain, err)
} else {
log.Tracef("pre-resolve completed for domain %s", domain)
}
}(domainStr)
}
}
// resolveAndUpdateDomain performs DNS resolution and updates domain prefixes
func (d *DnsInterceptor) resolveAndUpdateDomain(ctx context.Context, qDomain string) error {
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
if peerKey == "" {
return fmt.Errorf("no current peer key")
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
return fmt.Errorf("get upstream IP: %v", err)
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(qDomain), dns.TypeA)
msg.Id = dns.Id()
msg.MsgHdr.AuthenticatedData = true
reply, err := d.exchangeWithUpstream(ctx, msg, upstreamIP)
if err != nil {
return fmt.Errorf("exchange with upstream: %v", err)
}
if reply == nil || len(reply.Answer) == 0 {
return nil
}
resolvedDomain := domain.Domain(dns.Fqdn(qDomain))
return d.processResolveResponse(reply, resolvedDomain, resolvedDomain)
}
// exchangeWithUpstream performs DNS exchange with the upstream server
func (d *DnsInterceptor) exchangeWithUpstream(ctx context.Context, msg *dns.Msg, upstreamIP netip.Addr) (*dns.Msg, error) {
client := &dns.Client{
Timeout: nbdns.UpstreamTimeout,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, msg, upstream)
return reply, err
}
// extractIPsFromDNSResponse extracts IP addresses from DNS answer records
func (d *DnsInterceptor) extractIPsFromDNSResponse(reply *dns.Msg, domainForLogging domain.Domain) []netip.Prefix {
if reply == nil || len(reply.Answer) == 0 {
return nil
}
var prefixes []netip.Prefix
for _, answer := range reply.Answer {
var ip netip.Addr
switch rr := answer.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", domainForLogging, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", domainForLogging, rr.AAAA)
continue
}
ip = addr
default:
continue
}
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
prefixes = append(prefixes, prefix)
}
return prefixes
}
// processResolveResponse extracts IPs from DNS response and updates domain prefixes
func (d *DnsInterceptor) processResolveResponse(reply *dns.Msg, resolvedDomain, originalDomain domain.Domain) error {
newPrefixes := d.extractIPsFromDNSResponse(reply, resolvedDomain)
if len(newPrefixes) > 0 {
return d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes)
}
return nil
}
func (d *DnsInterceptor) RemoveRoute() error {
d.mu.Lock()
@@ -226,7 +113,6 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
}
d.currentPeerKey = peerKey
go d.preResolveDomains()
return nberrors.FormatErrorOrNil(merr)
}
@@ -279,8 +165,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
reply, err := d.exchangeWithUpstream(context.TODO(), r, upstreamIP)
client := &dns.Client{
Timeout: nbdns.UpstreamTimeout,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
@@ -345,13 +235,43 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
// already punycode via RegisterHandler()
originalDomain := domain.Domain(origPattern)
if originalDomain == "" {
originalDomain = resolvedDomain
}
if err := d.processResolveResponse(r, resolvedDomain, originalDomain); err != nil {
log.Errorf("failed to process DNS response: %v", err)
var newPrefixes []netip.Prefix
for _, answer := range r.Answer {
var ip netip.Addr
switch rr := answer.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
default:
continue
}
prefix := netip.PrefixFrom(ip, ip.BitLen())
newPrefixes = append(newPrefixes, prefix)
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
}
}

View File

@@ -2,15 +2,14 @@ package iface
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgIfaceBase interface {
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Name() string
Address() wgaddr.Address

View File

@@ -11,11 +11,9 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
@@ -23,11 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/server"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
@@ -41,8 +37,7 @@ import (
// Manager is a route manager interface
type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
@@ -73,9 +68,9 @@ type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[route.HAUniqueID]*client.Watcher
clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter *server.Router
serverRouter *serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
relayMgr *relayClient.Manager
@@ -93,7 +88,6 @@ type DefaultManager struct {
useNewDNSRoute bool
disableClientRoutes bool
disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler
}
func NewManager(config ManagerConfig) *DefaultManager {
@@ -105,7 +99,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
ctx: mCTX,
stop: cancel,
dnsRouteInterval: config.DNSRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*client.Watcher),
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: config.RelayManager,
sysOps: sysOps,
statusRecorder: config.StatusRecorder,
@@ -117,7 +111,6 @@ func NewManager(config ManagerConfig) *DefaultManager {
peerStore: config.PeerStore,
disableClientRoutes: config.DisableClientRoutes,
disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
}
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
@@ -159,10 +152,10 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
m.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function
return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix)
return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix.String())
},
func(prefix netip.Prefix, peerKey string) error {
if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix); err != nil {
if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
return err
}
@@ -233,7 +226,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
}
var err error
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil {
return err
}
@@ -244,7 +237,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop()
if m.serverRouter != nil {
m.serverRouter.CleanUp()
m.serverRouter.cleanUp()
}
if m.routeRefCounter != nil {
@@ -272,60 +265,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
toAdd := make(map[route.HAUniqueID]*route.Route)
toRemove := make(map[route.HAUniqueID]client.RouteHandler)
for id, routes := range newRoutes {
if len(routes) > 0 {
toAdd[id] = routes[0]
}
}
for id, activeHandler := range m.activeRoutes {
if _, exists := toAdd[id]; exists {
delete(toAdd, id)
} else {
toRemove[id] = activeHandler
}
}
var merr *multierror.Error
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
}
delete(m.activeRoutes, id)
}
for id, route := range toAdd {
handler := client.HandlerFromRoute(
route,
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsRouteInterval,
m.statusRecorder,
m.wgInterface,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue
}
m.activeRoutes[id] = handler
}
return nberrors.FormatErrorOrNil(merr)
}
func (m *DefaultManager) UpdateRoutes(
updateSerial uint64,
serverRoutes map[route.ID]*route.Route,
clientRoutes route.HAMap,
useNewDNSRoute bool,
) error {
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
@@ -337,28 +277,24 @@ func (m *DefaultManager) UpdateRoutes(
defer m.mux.Unlock()
m.useNewDNSRoute = useNewDNSRoute
var merr *multierror.Error
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(clientRoutes)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
}
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)
}
m.clientRoutes = clientRoutes
m.clientRoutes = newClientRoutesIDMap
if m.serverRouter == nil {
return nberrors.FormatErrorOrNil(merr)
return nil
}
if err := m.serverRouter.UpdateRoutes(serverRoutes, useNewDNSRoute); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err))
if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
return fmt.Errorf("update routes: %w", err)
}
return nberrors.FormatErrorOrNil(merr)
return nil
}
// SetRouteChangeListener set RouteListener for route change Notifier
@@ -405,10 +341,6 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.notifier.OnNewRoutes(networks)
if err := m.updateSystemRoutes(networks); err != nil {
log.Errorf("failed to update system routes during selection: %v", err)
}
m.stopObsoleteClients(networks)
for id, routes := range networks {
@@ -417,24 +349,21 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue
}
handler := m.activeRoutes[id]
if handler == nil {
log.Warnf("no active handler found for route %s", id)
continue
}
config := client.WatcherConfig{
Context: m.ctx,
DNSRouteInterval: m.dnsRouteInterval,
WGInterface: m.wgInterface,
StatusRecorder: m.statusRecorder,
Route: routes[0],
Handler: handler,
}
clientNetworkWatcher := client.NewWatcher(config)
clientNetworkWatcher := newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
@@ -446,7 +375,8 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok {
client.Stop()
log.Debugf("Stopping client network watcher, %s", id)
client.cancel()
delete(m.clientNetworks, id)
}
}
@@ -459,33 +389,30 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
handler := m.activeRoutes[id]
if handler == nil {
log.Errorf("No active handler found for route %s", id)
continue
}
config := client.WatcherConfig{
Context: m.ctx,
DNSRouteInterval: m.dnsRouteInterval,
WGInterface: m.wgInterface,
StatusRecorder: m.statusRecorder,
Route: routes[0],
Handler: handler,
}
clientNetworkWatcher = client.NewWatcher(config)
clientNetworkWatcher = newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := client.RoutesUpdate{
UpdateSerial: updateSerial,
Routes: routes,
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.SendUpdate(update)
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}
func (m *DefaultManager) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
newClientRoutesIDMap := make(route.HAMap)
newServerRoutesMap := make(map[route.ID]*route.Route)
ownNetworkIDs := make(map[route.HAUniqueID]bool)
@@ -512,7 +439,7 @@ func (m *DefaultManager) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]
}
func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.ClassifyRoutes(initialRoutes)
_, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0, len(crMap))
for _, routes := range crMap {
rs = append(rs, routes...)

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/netip"
"runtime"
"testing"
"github.com/pion/transport/v3/stdnet"
@@ -44,7 +45,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -71,7 +72,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.252.248/30"),
Network: netip.MustParsePrefix("100.64.252.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -99,7 +100,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.248/30"),
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -127,7 +128,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.248/30"),
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -211,7 +212,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -233,7 +234,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -250,7 +251,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -272,7 +273,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -282,7 +283,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "b",
NetID: "routeA",
Peer: remotePeerKey2,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -299,7 +300,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -327,7 +328,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -356,7 +357,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "l1",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -376,7 +377,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "r1",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.248/30"),
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@@ -439,14 +440,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
routeManager.serverRouter = nil
}
serverRoutes, clientRoutes := routeManager.ClassifyRoutes(testCase.inputRoutes)
if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, serverRoutes, clientRoutes, false)
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
require.NoError(t, err, "should update routes with init routes")
}
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), serverRoutes, clientRoutes, false)
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected
@@ -455,8 +454,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
if routeManager.serverRouter != nil {
require.Equal(t, testCase.serverRoutesExpected, routeManager.serverRouter.RoutesCount(), "server networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
}
})
}

View File

@@ -14,8 +14,7 @@ import (
// MockManager is the mock instance of a route manager
type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
@@ -33,21 +32,13 @@ func (m *MockManager) InitialRouteRange() []string {
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error {
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes, clientRoutes, useNewDNSRoute)
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return nil
}
// ClassifyRoutes mock implementation of ClassifyRoutes from Manager interface
func (m *MockManager) ClassifyRoutes(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
if m.ClassifyRoutesFunc != nil {
return m.ClassifyRoutesFunc(routes)
}
return nil, nil
}
func (m *MockManager) TriggerSelection(networks route.HAMap) {
if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks)

View File

@@ -1,4 +1,4 @@
package server
package routemanager
import (
"context"
@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/route"
)
type Router struct {
type serverRouter struct {
mux sync.Mutex
ctx context.Context
routes map[route.ID]*route.Route
@@ -23,8 +23,8 @@ type Router struct {
statusRecorder *peer.Status
}
func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) {
return &Router{
func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
return &serverRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),
firewall: firewall,
@@ -33,110 +33,104 @@ func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall
}, nil
}
func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
r.mux.Lock()
defer r.mux.Unlock()
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
m.mux.Lock()
defer m.mux.Unlock()
serverRoutesToRemove := make([]route.ID, 0)
for routeID := range r.routes {
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.Equal(r.routes[routeID]) {
if !found || !update.Equal(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := r.routes[routeID]
err := r.removeFromServerNetwork(oldRoute)
oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(r.routes, routeID)
delete(m.routes, routeID)
}
// If routing is to be disabled, do it after routes have been removed
// If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled
if len(routesMap) > 0 {
if err := r.firewall.EnableRouting(); err != nil {
if err := m.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err)
}
} else {
if err := r.firewall.DisableRouting(); err != nil {
if err := m.firewall.DisableRouting(); err != nil {
return fmt.Errorf("disable routing: %w", err)
}
}
for id, newRoute := range routesMap {
_, found := r.routes[id]
_, found := m.routes[id]
if found {
continue
}
err := r.addToServerNetwork(newRoute, useNewDNSRoute)
err := m.addToServerNetwork(newRoute, useNewDNSRoute)
if err != nil {
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
r.routes[id] = newRoute
m.routes[id] = newRoute
}
return nil
}
func (r *Router) removeFromServerNetwork(route *route.Route) error {
if r.ctx.Err() != nil {
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
if m.ctx.Err() != nil {
log.Infof("Not removing from server network because context is done")
return r.ctx.Err()
return m.ctx.Err()
}
routerPair := routeToRouterPair(route, false)
if err := r.firewall.RemoveNatRule(routerPair); err != nil {
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
delete(r.routes, route.ID)
r.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
delete(m.routes, route.ID)
m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
return nil
}
func (r *Router) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
if r.ctx.Err() != nil {
func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
if m.ctx.Err() != nil {
log.Infof("Not adding to server network because context is done")
return r.ctx.Err()
return m.ctx.Err()
}
routerPair := routeToRouterPair(route, useNewDNSRoute)
if err := r.firewall.AddNatRule(routerPair); err != nil {
if err := m.firewall.AddNatRule(routerPair); err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
r.routes[route.ID] = route
r.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
m.routes[route.ID] = route
m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
return nil
}
func (r *Router) CleanUp() {
r.mux.Lock()
defer r.mux.Unlock()
func (m *serverRouter) cleanUp() {
m.mux.Lock()
defer m.mux.Unlock()
for _, route := range r.routes {
routerPair := routeToRouterPair(route, false)
if err := r.firewall.RemoveNatRule(routerPair); err != nil {
for _, r := range m.routes {
routerPair := routeToRouterPair(r, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
}
}
r.statusRecorder.CleanLocalPeerStateRoutes()
}
func (r *Router) RoutesCount() int {
r.mux.Lock()
defer r.mux.Unlock()
return len(r.routes)
m.statusRecorder.CleanLocalPeerStateRoutes()
}
func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {

View File

@@ -24,22 +24,19 @@ func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allo
}
}
// Route route methods
func (r *Route) String() string {
return r.route.Network.String()
}
func (r *Route) AddRoute(context.Context) error {
if _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}); err != nil {
return err
}
return nil
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
return err
}
func (r *Route) RemoveRoute() error {
if _, err := r.routeRefCounter.Decrement(r.route.Network); err != nil {
return err
}
return nil
_, err := r.routeRefCounter.Decrement(r.route.Network)
return err
}
func (r *Route) AddAllowedIPs(peerKey string) error {
@@ -55,8 +52,6 @@ func (r *Route) AddAllowedIPs(peerKey string) error {
}
func (r *Route) RemoveAllowedIPs() error {
if _, err := r.allowedIPsRefcounter.Decrement(r.route.Network); err != nil {
return err
}
return nil
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
return err
}

View File

@@ -13,7 +13,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
)
const (
@@ -22,13 +22,8 @@ const (
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
type iface interface {
Address() wgaddr.Address
Name() string
}
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface iface) (map[string]int, error) {
func Setup(wgIface iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error

View File

@@ -6,10 +6,9 @@ import (
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type Nexthop struct {
@@ -31,16 +30,11 @@ func (n Nexthop) String() string {
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
}
type wgIface interface {
Address() wgaddr.Address
Name() string
}
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface wgIface
wgInterface iface.WGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
@@ -51,27 +45,9 @@ type SysOps struct {
notifier *notifier.Notifier
}
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
}
}
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr()
switch {
case
!addr.IsValid(),
addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsMulticast(),
addr.IsUnspecified() && prefix.Bits() != 0,
r.wgInterface.Address().Network.Contains(addr):
return vars.ErrRouteNotAllowed
}
return nil
}

View File

@@ -8,8 +8,6 @@ import (
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
@@ -35,12 +33,7 @@ func init() {
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil, nil)
@@ -50,7 +43,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
@@ -66,7 +59,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
@@ -126,39 +119,18 @@ func TestBits(t *testing.T) {
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
return "lo0"
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
t.Helper()
var originalNexthop net.IP
@@ -204,40 +176,12 @@ func fetchOriginalGateway() (net.IP, error) {
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -105,15 +106,59 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
if err := r.validateRoute(prefix); err != nil {
return Nexthop{}, err
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
if addr.IsUnspecified() {
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return Nexthop{}, vars.ErrRouteNotAllowed
}
@@ -134,7 +179,10 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
Intf: nexthop.Intf,
}
vpnAddr := vpnIntf.Address().IP
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
@@ -223,7 +271,32 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
return nil
}
return r.addToRouteTable(prefix, nextHop)
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
@@ -335,9 +408,13 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
@@ -380,6 +457,32 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()

View File

@@ -3,25 +3,23 @@
package systemops
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os/exec"
"os"
"runtime"
"strconv"
"strings"
"syscall"
"testing"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type dialer interface {
@@ -29,370 +27,105 @@ type dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddVPNRoute(t *testing.T) {
func TestAddRemoveRoutes(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
expectError bool
name string
prefix netip.Prefix
shouldRouteToWireguard bool
shouldBeRemoved bool
}{
{
name: "IPv4 - Private network route",
prefix: netip.MustParsePrefix("10.10.100.0/24"),
name: "Should Add And Remove Route 100.66.120.0/24",
prefix: netip.MustParsePrefix("100.66.120.0/24"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
},
{
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("10.111.111.111/32"),
},
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Default route",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
},
{
name: "IPv6 Default route",
prefix: netip.MustParsePrefix("::/0"),
},
// IPv4 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local multicast",
prefix: netip.MustParsePrefix("224.0.0.251/32"),
expectError: true,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
},
{
name: "IPv4 Unspecified with prefix",
prefix: netip.MustParsePrefix("0.0.0.0/32"),
expectError: true,
},
// IPv6 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local multicast",
prefix: netip.MustParsePrefix("ff02::1/128"),
expectError: true,
},
{
name: "IPv6 Interface-local multicast",
prefix: netip.MustParsePrefix("ff01::1/128"),
expectError: true,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
},
{
name: "IPv6 Unspecified with prefix",
prefix: netip.MustParsePrefix("::/128"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network subnet",
prefix: netip.MustParsePrefix("100.65.75.0/32"),
expectError: true,
name: "Should Not Add Or Remove Route 127.0.0.1/32",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
shouldRouteToWireguard: false,
shouldBeRemoved: false,
},
}
for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
_, _, err = r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
intf, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err)
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// add the route
err = r.AddVPNRoute(testCase.prefix, intf)
if testCase.expectError {
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
}
require.NoError(t, err, "genericAddVPNRoute should not return err")
// validate it's pointing to the WireGuard interface
require.NoError(t, err)
nextHop := getNextHop(t, testCase.prefix.Addr())
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
// remove route again
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err)
// validate it's gone
nextHop, err = GetNextHop(testCase.prefix.Addr())
require.True(t,
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
"err: %v, next hop: %v", err, nextHop)
})
}
}
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
t.Helper()
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
nextHop, err := GetNextHop(addr)
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
// present in the route table.
t.Skip("Skipping windows test")
}
require.NoError(t, err)
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
return nextHop
}
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
if addr.IsUnspecified() {
// On macOS, querying 0.0.0.0 returns the wrong interface
if addr.Is4() {
addr = netip.MustParseAddr("1.2.3.4")
} else {
addr = netip.MustParseAddr("2001:db8::1")
}
}
cmd := exec.Command("route", "-n", "get", addr.String())
if addr.Is6() {
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
}
output, err := cmd.CombinedOutput()
t.Logf("route output: %s", output)
require.NoError(t, err, "%s failed")
lines := strings.Split(string(output), "\n")
var intf string
var gateway string
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "interface:") {
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
} else if strings.HasPrefix(line, "gateway:") {
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
}
}
require.NotEmpty(t, intf, "interface should be found in route output")
iface, err := net.InterfaceByName(intf)
require.NoError(t, err, "interface %s should exist", intf)
nexthop := Nexthop{Intf: iface}
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
addr, err := netip.ParseAddr(gateway)
if err == nil {
nexthop.IP = addr
}
}
return nexthop
}
func TestAddRouteToNonVPNIntf(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
expectError bool
errorType error
}{
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("8.8.8.8/32"),
},
{
name: "IPv6 External network route",
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2001:db8::1/128"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
},
// Addresses that should be rejected
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Unspecified",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Unspecified",
prefix: netip.MustParsePrefix("::/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() &&
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
t.Skip("Skipping test as no ipv6 default route is available")
}
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Fatalf("Failed to get IPv6 default route: %v", err)
}
var initialNextHop Nexthop
if testCase.prefix.Addr().Is6() {
initialNextHop = initialNextHopV6
if testCase.shouldRouteToWireguard {
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
} else {
initialNextHop = initialNextHopV4
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
}
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if testCase.expectError {
require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
}
}
require.NoError(t, err)
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
// Verify the route was added and points to non-VPN interface
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err)
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
err = r.removeFromRouteTable(testCase.prefix, nexthop)
assert.NoError(t, err)
})
}
}
func TestGetNextHop(t *testing.T) {
defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if runtime.GOOS == "freebsd" {
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if !defaultNh.IP.IsValid() {
if !nexthop.IP.IsValid() {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
@@ -400,6 +133,7 @@ func TestGetNextHop(t *testing.T) {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
@@ -407,23 +141,213 @@ func TestGetNextHop(t *testing.T) {
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
nh, err := GetNextHop(testingPrefix.Addr())
localIP, err := GetNextHop(testingPrefix.Addr())
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if nh.Intf == nil {
if !localIP.IP.IsValid() {
t.Fatal("should return a gateway for local network")
}
if nh.IP.String() == defaultNh.IP.String() {
t.Fatal("next hop IP should not match with default gateway IP")
if localIP.IP.String() == nexthop.IP.String() {
t.Fatal("local IP should not match with gateway IP")
}
if nh.Intf.Name != defaultNh.Intf.Name {
t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
if localIP.IP.String() != testingIP {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPort: 33100,
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface, nil)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
switch {
case p.Addr().Is6():
continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}
@@ -460,16 +384,11 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
if err := r.AddVPNRoute(prefix, intf); err != nil {
if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("addVPNRoute should not return err: %v", err)
}
t.Logf("addVPNRoute %v returned: %v", prefix, err)
}
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("removeVPNRoute should not return err: %v", err)
}
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
@@ -503,10 +422,28 @@ func setupTestEnv(t *testing.T) {
// 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
// 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string

View File

@@ -149,10 +149,6 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() {
return r.genericAddVPNRoute(prefix, intf)
}
@@ -176,10 +172,6 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() {
return r.genericRemoveVPNRoute(prefix, intf)
}
@@ -227,7 +219,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr.Unmap(), ones)
prefix := netip.PrefixFrom(addr, ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
}
@@ -255,7 +247,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
return fmt.Errorf("netlink add route: %w", err)
}
@@ -278,7 +270,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet,
}
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
return fmt.Errorf("netlink add unreachable route: %w", err)
}

View File

@@ -19,6 +19,7 @@ import (
)
var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
@@ -30,6 +31,12 @@ func init() {
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
},
{
name: "To more specific route (local) without custom dialer via physical interface",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...)
}

View File

@@ -11,16 +11,10 @@ import (
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericRemoveVPNRoute(prefix, intf)
}

View File

@@ -1,268 +0,0 @@
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type mockWGIface struct {
address wgaddr.Address
name string
}
func (m *mockWGIface) Address() wgaddr.Address {
return m.address
}
func (m *mockWGIface) Name() string {
return m.name
}
func TestSysOps_validateRoute(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
}{
// Valid routes
{
name: "valid IPv4 route",
prefix: "192.168.1.0/24",
expectError: false,
},
{
name: "valid IPv6 route",
prefix: "2001:db8::/32",
expectError: false,
},
{
name: "valid single IPv4 host",
prefix: "8.8.8.8/32",
expectError: false,
},
{
name: "valid single IPv6 host",
prefix: "2001:4860:4860::8888/128",
expectError: false,
},
// Invalid routes - loopback
{
name: "IPv4 loopback",
prefix: "127.0.0.1/32",
expectError: true,
},
{
name: "IPv6 loopback",
prefix: "::1/128",
expectError: true,
},
// Invalid routes - link-local unicast
{
name: "IPv4 link-local unicast",
prefix: "169.254.1.1/32",
expectError: true,
},
{
name: "IPv6 link-local unicast",
prefix: "fe80::1/128",
expectError: true,
},
// Invalid routes - multicast
{
name: "IPv4 multicast",
prefix: "224.0.0.1/32",
expectError: true,
},
{
name: "IPv6 multicast",
prefix: "ff02::1/128",
expectError: true,
},
// Invalid routes - link-local multicast
{
name: "IPv4 link-local multicast",
prefix: "224.0.0.0/24",
expectError: true,
},
{
name: "IPv6 link-local multicast",
prefix: "ff02::/16",
expectError: true,
},
// Invalid routes - interface-local multicast (IPv6 only)
{
name: "IPv6 interface-local multicast",
prefix: "ff01::1/128",
expectError: true,
},
// Invalid routes - overlaps with WG interface network
{
name: "overlaps with WG network - exact match",
prefix: "10.0.0.0/24",
expectError: true,
},
{
name: "overlaps with WG network - subset",
prefix: "10.0.0.1/32",
expectError: true,
},
{
name: "overlaps with WG network - host in range",
prefix: "10.0.0.100/32",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
}
})
}
}
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
description string
}{
{
name: "identical subnet",
prefix: "192.168.100.0/24",
expectError: true,
description: "exact same network as WG interface",
},
{
name: "broader subnet containing WG network",
prefix: "192.168.0.0/16",
expectError: false,
description: "broader network that contains WG network should be allowed",
},
{
name: "host within WG network",
prefix: "192.168.100.50/32",
expectError: true,
description: "specific host within WG network",
},
{
name: "subnet within WG network",
prefix: "192.168.100.128/25",
expectError: true,
description: "smaller subnet within WG network",
},
{
name: "adjacent subnet - same /23",
prefix: "192.168.101.0/24",
expectError: false,
description: "adjacent subnet, no overlap",
},
{
name: "adjacent subnet - different /16",
prefix: "192.167.100.0/24",
expectError: false,
description: "different network, no overlap",
},
{
name: "WG network broadcast address",
prefix: "192.168.100.255/32",
expectError: true,
description: "broadcast address of WG network",
},
{
name: "WG network first usable",
prefix: "192.168.100.1/32",
expectError: true,
description: "first usable address in WG network",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
}
})
}
}
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wt0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
var invalidPrefix netip.Prefix
err := sysOps.validateRoute(invalidPrefix)
require.Error(t, err, "validateRoute() expected error for invalid prefix")
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
}

View File

@@ -3,19 +3,15 @@
package systemops
import (
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"syscall"
"os/exec"
"strings"
"time"
"unsafe"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -30,16 +26,48 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
return r.routeCmd("add", prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
return r.routeCmd("delete", prefix, nexthop)
}
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
if !prefix.IsValid() {
return fmt.Errorf("invalid prefix: %s", prefix)
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet"
if prefix.Addr().Is6() {
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
@@ -47,157 +75,9 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
}
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
}
return nil
}
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Version: unix.RTM_VERSION,
Seq: 1,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
err := backoff.Retry(operation, expBackOff)
if err != nil {
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
return fmt.Errorf("route cmd retry failed: %w", err)
}
if prefix.IsSingleIP() {
msg.Flags |= unix.RTF_HOST
} else {
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
if err != nil {
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
}
}
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
}
} else if nexthop.Intf != nil {
msg.Index = nexthop.Intf.Index
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {
return &route.Inet4Addr{IP: addr.As4()}, nil
}
if addr.Zone() == "" {
return &route.Inet6Addr{IP: addr.As16()}, nil
}
var zone int
// zone can be either a numeric zone ID or an interface name.
if z, err := strconv.Atoi(addr.Zone()); err == nil {
zone = z
} else {
iface, err := net.InterfaceByName(addr.Zone())
if err != nil {
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
}
zone = iface.Index
}
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
}
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
m := net.CIDRMask(bits, 32)
var maskBytes [4]byte
copy(maskBytes[:], m)
return &route.Inet4Addr{IP: maskBytes}, nil
}
if prefix.Addr().Is6() {
m := net.CIDRMask(bits, 128)
var maskBytes [16]byte
copy(maskBytes[:], m)
return &route.Inet6Addr{IP: maskBytes}, nil
}
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
}

Some files were not shown because too many files have changed in this diff Show More