Compare commits

..

4 Commits

Author SHA1 Message Date
Zoltán Papp
d9bcdcf149 Fix lint 2024-05-15 03:26:30 +02:00
Zoltán Papp
d39814f173 Fix lint 2024-05-15 02:26:29 +02:00
Zoltán Papp
4a2429eb1c Fix cleanup and error handling 2024-05-15 01:36:40 +02:00
Zoltán Papp
de2e6557ad Revert context changes in proxy implementations 2024-05-15 00:27:40 +02:00
311 changed files with 7225 additions and 13881 deletions

View File

@@ -14,7 +14,7 @@ jobs:
test: test:
strategy: strategy:
matrix: matrix:
store: ['sqlite'] store: ['jsonfile', 'sqlite']
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go

View File

@@ -1,39 +0,0 @@
name: Test Code FreeBSD
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
prepare: |
pkg install -y curl
pkg install -y git
run: |
set -x
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
tar zxf go.tar.gz
mv go /usr/local/go
ln -s /usr/local/go/bin/go /usr/local/bin/go
go mod tidy
go test -timeout 5m -p 1 ./iface/...
go test -timeout 5m -p 1 ./client/...
cd client
go build .
cd ..

View File

@@ -15,7 +15,7 @@ jobs:
strategy: strategy:
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'jsonfile', 'sqlite' ]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
@@ -86,10 +86,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin - name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin - name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@@ -111,9 +108,6 @@ jobs:
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker - name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -173,7 +173,7 @@ jobs:
retention-days: 3 retention-days: 3
release_ui_darwin: release_ui_darwin:
runs-on: macos-latest runs-on: macos-11
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV

View File

@@ -178,79 +178,34 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL - name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres - name: test Caddy file gen
run: test -f Caddyfile run: test -f Caddyfile
- name: test docker-compose file gen
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml run: test -f docker-compose.yml
- name: test management.json file gen
- name: test management.json file gen postgres
run: test -f management.json run: test -f management.json
- name: test turnserver.conf file gen
- name: test turnserver.conf file gen postgres
run: | run: |
set -x set -x
test -f turnserver.conf test -f turnserver.conf
grep external-ip turnserver.conf grep external-ip turnserver.conf
- name: test zitadel.env file gen
- name: test zitadel.env file gen postgres
run: test -f zitadel.env run: test -f zitadel.env
- name: test dashboard.env file gen
- name: test dashboard.env file gen postgres
run: test -f dashboard.env run: test -f dashboard.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script: test-download-geolite2-script:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install jq - name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3 run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: test script - name: test script
run: bash -x infrastructure_files/download-geolite2.sh run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists - name: test mmdb file exists
run: test -f GeoLite2-City.mmdb run: test -f GeoLite2-City.mmdb
- name: test geonames file exists - name: test geonames file exists
run: test -f geonames.db run: test -f geonames.db

View File

@@ -130,10 +130,3 @@ issues:
- path: mock\.go - path: mock\.go
linters: linters:
- nilnil - nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -3,10 +3,8 @@ builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
dir: client/ui dir: client/ui
binary: netbird-ui binary: netbird-ui
env: env: [CGO_ENABLED=1]
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
goos: goos:
- darwin - darwin
goarch: goarch:

View File

@@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socioeconomic status, identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation. identity and orientation.

View File

@@ -1,4 +1,4 @@
FROM alpine:3.19 FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -57,17 +57,15 @@ type Client struct {
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{ return &Client{
cfgFile: cfgFile, cfgFile: cfgFile,
deviceName: deviceName, deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""), recorder: peer.NewRecorder(""),
@@ -90,9 +88,6 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
var ctx context.Context var ctx context.Context
//nolint //nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.UiVersionCtxKey, c.uiVersion)
c.ctxCancelLock.Lock() c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues) ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel() defer c.ctxCancel()

View File

@@ -3,14 +3,13 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
) )
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
@@ -59,7 +58,7 @@ var forCmd = &cobra.Command{
} }
func debugBundle(cmd *cobra.Command, _ []string) error { func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@@ -80,14 +79,14 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
} }
func setLogLevel(cmd *cobra.Command, args []string) error { func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() defer conn.Close()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0]) level := parseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN { if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0]) return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
} }
@@ -103,13 +102,34 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func parseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}
func runForDuration(cmd *cobra.Command, args []string) error { func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0]) duration, err := time.ParseDuration(args[0])
if err != nil { if err != nil {
return fmt.Errorf("invalid duration format: %v", err) return fmt.Errorf("invalid duration format: %v", err)
} }
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@@ -117,33 +137,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down") cmd.Println("Netbird down")
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{ _, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE, Level: proto.LogLevel_TRACE,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to set log level to TRACE: %v", status.Convert(err).Message()) return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
} }
cmd.Println("Log level set to trace.") cmd.Println("Log level set to trace.")
}
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@@ -170,22 +175,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
} }
cmd.Println("Netbird down") cmd.Println("Netbird down")
// TODO reset log level
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if restoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Creating debug bundle...") cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{

View File

@@ -2,9 +2,8 @@ package cmd
import ( import (
"context" "context"
"time"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"

View File

@@ -36,7 +36,6 @@ const (
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
) )
var ( var (
@@ -69,8 +68,6 @@ var (
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool anonymizeFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
Short: "", Short: "",
@@ -356,11 +353,8 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true return true
} }
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { func getClient(ctx context.Context) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd) conn, err := DialClientGRPCServer(ctx, daemonAddr)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+

View File

@@ -2,7 +2,6 @@ package cmd
import ( import (
"fmt" "fmt"
"strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@@ -50,7 +49,7 @@ func init() {
} }
func routesList(cmd *cobra.Command, _ []string) error { func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@@ -67,62 +66,20 @@ func routesList(cmd *cobra.Command, _ []string) error {
return nil return nil
} }
printRoutes(cmd, resp) cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil return nil
} }
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error { func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }
@@ -149,7 +106,7 @@ func routesSelect(cmd *cobra.Command, args []string) error {
} }
func routesDeselect(cmd *cobra.Command, args []string) error { func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd.Context())
if err != nil { if err != nil {
return err return err
} }

View File

@@ -807,7 +807,11 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
} }
for i, route := range peer.Routes { for i, route := range peer.Routes {
peer.Routes[i] = anonymizeRoute(a, route) prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
} }
} }
@@ -843,21 +847,12 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
} }
for i, route := range overview.Routes { for i, route := range overview.Routes {
overview.Routes[i] = anonymizeRoute(a, route) prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
} }
overview.FQDN = a.AnonymizeDomain(overview.FQDN) overview.FQDN = a.AnonymizeDomain(overview.FQDN)
} }
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@@ -7,9 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -17,7 +14,6 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
@@ -56,10 +52,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
srv, err := sig.NewServer(otel.Meter("")) sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
go func() { go func() {
if err := s.Serve(lis); err != nil { if err := s.Serve(lis); err != nil {
panic(err) panic(err)
@@ -76,24 +69,23 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -108,7 +100,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
} }
func startClientDaemon( func startClientDaemon(
t *testing.T, ctx context.Context, _, configPath string, t *testing.T, ctx context.Context, managementURL, configPath string,
) (*grpc.Server, net.Listener) { ) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@@ -7,13 +7,11 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -42,12 +40,8 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -143,10 +137,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
} }
} }
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
@@ -247,10 +237,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.NetworkMonitor = &networkMonitor loginRequest.NetworkMonitor = &networkMonitor
} }
if cmd.Flag(dnsRouteIntervalFlag).Changed {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse

View File

@@ -1,30 +0,0 @@
package errors
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}
func FormatErrorOrNil(err *multierror.Error) error {
if err != nil {
err.ErrorFormat = formatError
}
return err.ErrorOrNil()
}

View File

@@ -42,20 +42,20 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Debug("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface) fm, errFw = nbiptables.Create(context, iface)
if errFw != nil { if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw) log.Errorf("failed to create iptables manager: %s", errFw)
} }
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Debug("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface) fm, errFw = nbnftables.Create(context, iface)
if errFw != nil { if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw) log.Errorf("failed to create nftables manager: %s", errFw)
} }
default: default:
errFw = fmt.Errorf("no firewall manager found") errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
} }
if iface.IsUserspaceBind() { if iface.IsUserspaceBind() {
@@ -85,58 +85,16 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType { func check() FWType {
useIPTABLES := false
var iptablesChains []string
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil && isIptablesClientAvailable(ip) {
major, minor, _ := ip.GetIptablesVersion()
// use iptables when its version is lower than 1.8.0 which doesn't work well with our nftables manager
if major < 1 || (major == 1 && minor < 8) {
return IPTABLES
}
useIPTABLES = true
iptablesChains, err = ip.ListChains("filter")
if err != nil {
log.Errorf("failed to list iptables chains: %s", err)
useIPTABLES = false
}
}
nf := nftables.Conn{} nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" { if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES {
return NFTABLES return NFTABLES
} }
// search for chains where table is filter ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
// if we find one, we assume that nftables manager can be used with iptables if err != nil {
for _, chain := range chains { return UNKNOWN
if chain.Table.Name == "filter" {
return NFTABLES
} }
} if isIptablesClientAvailable(ip) {
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
}
if useIPTABLES {
return IPTABLES return IPTABLES
} }

View File

@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil return nil
} }
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil { if err != nil {
return err return err
} }
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil { if err != nil {
return err return err
} }
@@ -101,7 +101,6 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
} }
delete(i.rules, ruleKey) delete(i.rules, ruleKey)
} }
err = i.iptablesClient.Insert(table, chain, 1, rule...) err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil { if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@@ -318,13 +317,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
} }
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil { if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@@ -334,30 +326,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil return nil
} }
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification // genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string { func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump} return []string{"-s", source, "-d", destination, "-j", jump}

View File

@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddRoutingRules(pair) return m.router.InsertRoutingRules(pair)
} }
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {

View File

@@ -22,8 +22,6 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc" userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst" userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
) )
// some presets for building nftable rules // some presets for building nftable rules
@@ -128,22 +126,6 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap() err := r.refreshRulesMap()
if err != nil { if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@@ -156,28 +138,28 @@ func (r *router) createContainers() error {
return nil return nil
} }
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain // InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error { func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap() err := r.refreshRulesMap()
if err != nil { if err != nil {
return err return err
} }
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil { if err != nil {
return err return err
} }
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil { if err != nil {
return err return err
} }
if pair.Masquerade { if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil { if err != nil {
return err return err
} }
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil { if err != nil {
return err return err
} }
@@ -195,8 +177,8 @@ func (r *router) AddRoutingRules(pair manager.RouterPair) error {
return nil return nil
} }
// addRoutingRule inserts a nftable rule to the conn client flush queue // insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -217,7 +199,7 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
} }
} }
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainName], Chain: r.chains[chainName],
Exprs: expression, Exprs: expression,

View File

@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.AddRoutingRules(testCase.InputPair) err = manager.InsertRoutingRules(testCase.InputPair)
defer func() { defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair) _ = manager.RemoveRoutingRules(testCase.InputPair)
}() }()

View File

@@ -6,16 +6,13 @@ import (
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
@@ -56,7 +53,6 @@ type ConfigInput struct {
NetworkMonitor *bool NetworkMonitor *bool
DisableAutoConnect *bool DisableAutoConnect *bool
ExtraIFaceBlackList []string ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
} }
// Config Configuration type // Config Configuration type
@@ -68,7 +64,7 @@ type Config struct {
AdminURL *url.URL AdminURL *url.URL
WgIface string WgIface string
WgPort int WgPort int
NetworkMonitor *bool NetworkMonitor bool
IFaceBlackList []string IFaceBlackList []string
DisableIPv6Discovery bool DisableIPv6Discovery bool
RosenpassEnabled bool RosenpassEnabled bool
@@ -99,9 +95,6 @@ type Config struct {
// DisableAutoConnect determines whether the client should not start with the service // DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility // it's set to false by default due to backwards compatibility
DisableAutoConnect bool DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -311,21 +304,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor { if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor {
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor) log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
config.NetworkMonitor = input.NetworkMonitor config.NetworkMonitor = *input.NetworkMonitor
updated = true updated = true
} }
if config.NetworkMonitor == nil {
// enable network monitoring by default on windows and darwin clients
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
enabled := true
config.NetworkMonitor = &enabled
updated = true
}
}
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress { if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)", log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress) string(input.CustomDNSAddress), config.CustomDNSAddress)
@@ -373,18 +357,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
config.DNSRouteInterval = *input.DNSRouteInterval
updated = true
} else if config.DNSRouteInterval == 0 {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
return updated, nil return updated, nil
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
@@ -92,9 +91,6 @@ func (c *ConnectClient) RunOniOS(
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager, dnsManager dns.IosDnsManager,
) error { ) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
mobileDependency := MobileDependency{ mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor, FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
@@ -252,10 +248,8 @@ func (c *ConnectClient) run(
return wrapErr(err) return wrapErr(err)
} }
checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks) c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
c.engineMutex.Unlock() c.engineMutex.Unlock()
err = c.engine.Start() err = c.engine.Start()
@@ -309,10 +303,6 @@ func (c *ConnectClient) Engine() *Engine {
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address, WgAddr: peerConfig.Address,
@@ -320,14 +310,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableIPv6Discovery: config.DisableIPv6Discovery, DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: config.WgPort, WgPort: config.WgPort,
NetworkMonitor: nm, NetworkMonitor: config.NetworkMonitor,
SSHKey: []byte(config.SSHKey), SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs, NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress, CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled, RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive, RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -338,15 +327,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
engineConf.PreSharedKey = &preSharedKey engineConf.PreSharedKey = &preSharedKey
} }
port, err := freePort(config.WgPort)
if err != nil {
return nil, err
}
if port != config.WgPort {
log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort)
}
engineConf.WgPort = port
return engineConf, nil return engineConf, nil
} }
@@ -396,20 +376,3 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
notifier, _ := sri.(signal.ConnStateNotifier) notifier, _ := sri.(signal.ConnStateNotifier)
return notifier return notifier
} }
func freePort(start int) (int, error) {
addr := net.UDPAddr{}
if start == 0 {
start = iface.DefaultWgPort
}
for x := start; x <= 65535; x++ {
addr.Port = x
conn, err := net.ListenUDP("udp", &addr)
if err != nil {
continue
}
conn.Close()
return x, nil
}
return 0, errors.New("no free ports")
}

View File

@@ -1,57 +0,0 @@
package internal
import (
"net"
"testing"
)
func Test_freePort(t *testing.T) {
tests := []struct {
name string
port int
want int
wantErr bool
}{
{
name: "available",
port: 51820,
want: 51820,
wantErr: false,
},
{
name: "notavailable",
port: 51830,
want: 51831,
wantErr: false,
},
{
name: "noports",
port: 65535,
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
if err != nil {
t.Errorf("freePort error = %v", err)
}
c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
if err != nil {
t.Errorf("freePort error = %v", err)
}
t.Run(tt.name, func(t *testing.T) {
got, err := freePort(tt.port)
if (err != nil) != tt.wantErr {
t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("freePort() = %v, want %v", got, tt.want)
}
})
c1.Close()
c2.Close()
}
}

View File

@@ -1,6 +0,0 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

View File

@@ -1,8 +0,0 @@
//go:build !android
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns
@@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil return networkManager, nil
} }
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() { if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
if checkStub() { if checkStub() {
return systemdManager, nil return systemdManager, nil
} else { } else {
@@ -116,10 +116,16 @@ func getOSDNSManagerType() (osManagerType, error) {
} }
} }
if strings.Contains(text, "resolvconf") { if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() { if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
var value string
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
if err == nil {
if value == systemdDbusResolvConfModeForeign {
return systemdManager, nil return systemdManager, nil
} }
}
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
}
return resolvConfManager, nil return resolvConfManager, nil
} }
} }

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns

View File

@@ -39,10 +39,6 @@ func (w *mocWGIface) Address() iface.WGAddress {
} }
} }
func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
}
func (w *mocWGIface) GetFilter() iface.PacketFilter { func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter return w.filter
} }
@@ -265,7 +261,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -343,7 +339,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil { if err != nil {
t.Errorf("build interface wireguard: %v", err) t.Errorf("build interface wireguard: %v", err)
return return
@@ -801,7 +797,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil { if err != nil {
t.Fatalf("build interface wireguard: %v", err) t.Fatalf("build interface wireguard: %v", err)
return nil, err return nil, err

View File

@@ -1,20 +0,0 @@
package dns
import (
"errors"
"fmt"
)
var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
}
func isSystemdResolvedRunning() bool {
return false
}
func isSystemdResolveConfMode() bool {
return false
}

View File

@@ -242,25 +242,3 @@ func getSystemdDbusProperty(property string, store any) error {
return v.Store(store) return v.Store(store)
} }
func isSystemdResolvedRunning() bool {
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
}
func isSystemdResolveConfMode() bool {
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return false
}
var value string
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
return false
}
if value == systemdDbusResolvConfModeForeign {
return true
}
return false
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build !android
package dns package dns
@@ -14,6 +14,11 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)
func CheckUncleanShutdown(wgIface string) error { func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -78,11 +79,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}() }()
log.WithField("question", r.Question[0]).Trace("received an upstream question") log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
@@ -264,11 +260,14 @@ func (u *upstreamResolverBase) disable(err error) {
return return
} }
// todo test the deactivation logic, it seems to affect the client
if runtime.GOOS != "ios" {
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate(err) u.deactivate(err)
u.disabled = true u.disabled = true
go u.waitUntilResponse() go u.waitUntilResponse()
} }
}
func (u *upstreamResolverBase) testNameserver(server string) error { func (u *upstreamResolverBase) testNameserver(server string) error {
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout) ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)

View File

@@ -2,17 +2,12 @@
package dns package dns
import ( import "github.com/netbirdio/netbird/iface"
"net"
"github.com/netbirdio/netbird/iface"
)
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() iface.PacketFilter GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper GetDevice() *iface.DeviceWrapper

View File

@@ -4,13 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"maps"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -29,21 +27,17 @@ import (
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -94,8 +88,6 @@ type EngineConfig struct {
RosenpassPermissive bool RosenpassPermissive bool
ServerSSHAllowed bool ServerSSHAllowed bool
DNSRouteInterval time.Duration
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -107,8 +99,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer // peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn peerConns map[string]*peer.Conn
beforePeerHook nbnet.AddHookFunc beforePeerHook peer.BeforeAddPeerHookFunc
afterPeerHook nbnet.RemoveHookFunc afterPeerHook peer.AfterRemovePeerHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager
@@ -126,7 +118,6 @@ type Engine struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap clientRoutes route.HAMap
clientRoutesMu sync.RWMutex
clientCtx context.Context clientCtx context.Context
clientCancel context.CancelFunc clientCancel context.CancelFunc
@@ -142,7 +133,7 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
networkMonitor *networkmonitor.NetworkMonitor networkWatcher *networkmonitor.NetworkWatcher
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server sshServer nbssh.Server
@@ -159,11 +150,6 @@ type Engine struct {
signalProbe *Probe signalProbe *Probe
relayProbe *Probe relayProbe *Probe
wgProbe *Probe wgProbe *Probe
wgConnWorker sync.WaitGroup
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -181,7 +167,6 @@ func NewEngine(
config *EngineConfig, config *EngineConfig,
mobileDep MobileDependency, mobileDep MobileDependency,
statusRecorder *peer.Status, statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine { ) *Engine {
return NewEngineWithProbes( return NewEngineWithProbes(
clientCtx, clientCtx,
@@ -195,7 +180,6 @@ func NewEngine(
nil, nil,
nil, nil,
nil, nil,
checks,
) )
} }
@@ -212,7 +196,6 @@ func NewEngineWithProbes(
signalProbe *Probe, signalProbe *Probe,
relayProbe *Probe, relayProbe *Probe,
wgProbe *Probe, wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
@@ -229,11 +212,11 @@ func NewEngineWithProbes(
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer, sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
networkWatcher: networkmonitor.New(),
mgmProbe: mgmProbe, mgmProbe: mgmProbe,
signalProbe: signalProbe, signalProbe: signalProbe,
relayProbe: relayProbe, relayProbe: relayProbe,
wgProbe: wgProbe, wgProbe: wgProbe,
checks: checks,
} }
} }
@@ -246,26 +229,20 @@ func (e *Engine) Stop() error {
} }
// stopping network monitor first to avoid starting the engine again // stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil { e.networkWatcher.Stop()
e.networkMonitor.Stop()
}
log.Info("Network monitor: stopped")
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
return err return err
} }
e.clientRoutesMu.Lock()
e.clientRoutes = nil e.clientRoutes = nil
e.clientRoutesMu.Unlock()
// very ugly but we want to remove peers from the WireGuard interface first before removing interface. // very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously // Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
e.close() e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
return nil return nil
} }
@@ -282,6 +259,8 @@ func (e *Engine) Start() error {
} }
e.ctx, e.cancel = context.WithCancel(e.clientCtx) e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.config.WgPort)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err) log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
@@ -289,9 +268,6 @@ func (e *Engine) Start() error {
} }
e.wgInterface = wgIface e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive { if e.config.RosenpassPermissive {
@@ -316,7 +292,7 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
@@ -368,8 +344,20 @@ func (e *Engine) Start() error {
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveProbeEvents() e.receiveProbeEvents()
if e.config.NetworkMonitor {
// starting network monitor at the very last to avoid disruptions // starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor() go e.networkWatcher.Start(e.ctx, func() {
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
})
} else {
log.Infof("Network monitor is disabled, not starting")
}
return nil return nil
} }
@@ -542,10 +530,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal // todo update signal
} }
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil { if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones // only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap()) err := e.updateNetworkMap(update.GetNetworkMap())
@@ -553,27 +537,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err return err
} }
} }
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil return nil
} }
@@ -589,8 +553,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} else { } else {
if sshConf.GetSshEnabled() { if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS) log.Warnf("running SSH server on Windows is not supported")
return nil return nil
} }
// start SSH server if it wasn't running // start SSH server if it wasn't running
@@ -663,14 +627,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it. // E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() { func (e *Engine) receiveManagementEvents() {
go func() { go func() {
info, err := system.GetInfoWithChecks(e.ctx, e.checks) err := e.mgmClient.Sync(e.ctx, e.handleSync)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client // We want to cancel the operation of the whole client
@@ -737,20 +694,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil return nil
} }
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -792,6 +735,17 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
} }
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutes = clientRoutes
protoDNSConfig := networkMap.GetDNSConfig() protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil { if protoDNSConfig == nil {
@@ -819,24 +773,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
routes := make([]*route.Route, 0) routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes { for _, protoRoute := range protoRoutes {
var prefix netip.Prefix _, prefix, _ := route.ParseNetwork(protoRoute.Network)
if len(protoRoute.Domains) == 0 {
var err error
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
continue
}
}
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix,
Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer, Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric), Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade, Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute,
} }
routes = append(routes, convertedRoute) routes = append(routes, convertedRoute)
} }
@@ -934,25 +879,18 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
} }
e.wgConnWorker.Add(1)
go e.connWorker(conn, peerKey) go e.connWorker(conn, peerKey)
} }
return nil return nil
} }
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) { func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
defer e.wgConnWorker.Done()
for { for {
// randomize starting time a bit // randomize starting time a bit
min := 500 min := 500
max := 2000 max := 2000
duration := time.Duration(rand.Intn(max-min)+min) * time.Millisecond time.Sleep(time.Duration(rand.Intn(max-min)+min) * time.Millisecond)
select {
case <-e.ctx.Done():
return
case <-time.After(duration):
}
// if peer has been removed -> give up // if peer has been removed -> give up
if !e.peerExists(peerKey) { if !e.peerExists(peerKey) {
@@ -1039,6 +977,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
WgConfig: wgConfig, WgConfig: wgConfig,
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
RosenpassPubKey: e.getRosenpassPubKey(), RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(), RosenpassAddr: e.getRosenpassAddr(),
} }
@@ -1101,6 +1040,8 @@ func (e *Engine) receiveSignalEvents() {
return err return err
} }
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
var rosenpassPubKey []byte var rosenpassPubKey []byte
rosenpassAddr := "" rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil { if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1123,6 +1064,8 @@ func (e *Engine) receiveSignalEvents() {
return err return err
} }
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
var rosenpassPubKey []byte var rosenpassPubKey []byte
rosenpassAddr := "" rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil { if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1145,8 +1088,7 @@ func (e *Engine) receiveSignalEvents() {
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err) log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
return err return err
} }
conn.OnRemoteCandidate(candidate)
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
case sProto.Body_MODE: case sProto.Body_MODE:
} }
@@ -1260,8 +1202,7 @@ func (e *Engine) close() {
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx) netMap, err := e.mgmClient.GetNetworkMap()
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -1290,7 +1231,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default: default:
} }
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes) return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
} }
func (e *Engine) wgInterfaceCreate() (err error) { func (e *Engine) wgInterfaceCreate() (err error) {
@@ -1341,17 +1282,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
// GetClientRoutes returns the current routes from the route map // GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap { func (e *Engine) GetClientRoutes() route.HAMap {
e.clientRoutesMu.RLock() return e.clientRoutes
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
} }
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes { for id, v := range e.clientRoutes {
routes[id.NetID()] = v routes[id.NetID()] = v
@@ -1464,72 +1399,3 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
func (e *Engine) probeTURNs() []relay.ProbeResult { func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs) return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
} }
func (e *Engine) restartEngine() {
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
return
}
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
debounceTimer.Stop()
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
}
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@@ -17,7 +17,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -58,9 +57,9 @@ var (
) )
func TestEngine_SSH(t *testing.T) { func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH") t.Skip("skipping TestEngine_SSH on Windows")
} }
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
@@ -78,7 +77,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
ServerSSHAllowed: true, ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -212,16 +211,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }
@@ -230,7 +229,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
type testCase struct { type testCase struct {
name string name string
@@ -394,7 +392,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client // feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse) updates := make(chan *mgmtProto.SyncResponse)
defer close(updates) defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error { syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates { for msg := range updates {
err := msgHandler(msg) err := msgHandler(msg)
if err != nil { if err != nil {
@@ -409,8 +407,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -568,13 +565,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@@ -738,14 +734,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
@@ -811,13 +805,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
ctx, cancel := context.WithCancel(CtxInitState(context.Background())) ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel() defer cancel()
sigServer, signalAddr, err := startSignal(t) sigServer, signalAddr, err := startSignal()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
defer sigServer.Stop() defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, dir) mgmtServer, mgmtAddr, err := startManagement(dir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1009,14 +1003,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort, WgPort: wgPort,
} }
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
e.ctx = ctx
return e, err
} }
func startSignal(t *testing.T) (*grpc.Server, string, error) { func startSignal() (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
@@ -1024,9 +1014,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err) log.Fatalf("failed to listen: %v", err)
} }
srv, err := signalServer.NewServer(otel.Meter("")) proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() { go func() {
if err = s.Serve(lis); err != nil { if err = s.Serve(lis); err != nil {
@@ -1037,9 +1025,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
return s, lis.Addr().String(), nil return s, lis.Addr().String(), nil
} }
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) { func startManagement(dataDir string) (*grpc.Server, string, error) {
t.Helper()
config := &server.Config{ config := &server.Config{
Stuns: []*server.Host{}, Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{}, TURNConfig: &server.TURNConfig{},
@@ -1056,25 +1042,23 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err return nil, "", err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStoreFromJson(config.Datadir, nil)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -2,20 +2,14 @@ package networkmonitor
import ( import (
"context" "context"
"errors"
"sync"
) )
var ErrStopped = errors.New("monitor has been stopped") // NetworkWatcher watches for changes in network configuration.
type NetworkWatcher struct {
// NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct {
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex
} }
// New creates a new network monitor. // New creates a new network monitor.
func New() *NetworkMonitor { func New() *NetworkWatcher {
return &NetworkMonitor{} return &NetworkWatcher{}
} }

View File

@@ -5,6 +5,8 @@ package networkmonitor
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip"
"syscall" "syscall"
"unsafe" "unsafe"
@@ -12,10 +14,10 @@ import (
"golang.org/x/net/route" "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
@@ -29,7 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
default: default:
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
@@ -45,6 +47,24 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type { switch msg.Type {
// handle interface state changes
case unix.RTM_IFINFO:
ifinfo, err := parseInterfaceMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing interface message: %v", err)
continue
}
if msg.Flags&unix.IFF_UP != 0 {
continue
}
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
continue
}
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
callback()
// handle route changes // handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE: case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n]) route, err := parseRouteMessage(buf[:n])
@@ -64,11 +84,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
switch msg.Type { switch msg.Type {
case unix.RTM_ADD: case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback() callback()
case unix.RTM_DELETE: case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback() callback()
} }
} }
} }
@@ -76,7 +96,25 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
} }
} }
func parseRouteMessage(buf []byte) (*systemops.Route, error) { func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.InterfaceMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return msg, nil
}
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err) return nil, fmt.Errorf("parse RIB: %v", err)
@@ -91,5 +129,5 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
} }
return systemops.MsgToRoute(msg) return routemanager.MsgToRoute(msg)
} }

View File

@@ -5,45 +5,48 @@ package networkmonitor
import ( import (
"context" "context"
"errors" "errors"
"fmt" "net"
"net/netip" "net/netip"
"runtime/debug" "runtime/debug"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager"
) )
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns. // Start begins watching for network changes and calls the callback function and stops when a change is detected.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) { func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
if ctx.Err() != nil { if nw.cancel != nil {
return ctx.Err() log.Warn("Network monitor: already running, stopping previous watcher")
nw.Stop()
}
if ctx.Err() != nil {
log.Info("Network monitor: not starting, context is already cancelled")
return
} }
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx) ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock() defer nw.Stop()
nw.wg.Add(1) var nexthop4, nexthop6 netip.Addr
defer nw.wg.Done() var intf4, intf6 *net.Interface
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error { operation := func() error {
var errv4, errv6 error var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified()) nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified()) nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil { if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops") return errors.New("failed to get default next hops")
} }
if errv4 == nil { if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name) log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
} }
if errv6 == nil { if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name) log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
} }
// continue if either route was found // continue if either route was found
@@ -53,30 +56,27 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil { if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err) log.Errorf("Network monitor: failed to get default next hops: %v", err)
return
} }
// recover in case sys ops panic // recover in case sys ops panic
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack())) log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
} }
}() }()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil { if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("check change: %w", err) log.Errorf("Network monitor: failed to start: %v", err)
} }
return nil
} }
// Stop stops the network monitor. // Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() { func (nw *NetworkWatcher) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil { if nw.cancel != nil {
nw.cancel() nw.cancel()
nw.wg.Wait() nw.cancel = nil
log.Info("Network monitor: stopped")
} }
} }

View File

@@ -6,22 +6,27 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip"
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil { if intfv4 == nil && intfv6 == nil {
return errors.New("no interfaces available") return errors.New("no interfaces available")
} }
linkChan := make(chan netlink.LinkUpdate)
done := make(chan struct{}) done := make(chan struct{})
defer close(done) defer close(done)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
return fmt.Errorf("subscribe to link updates: %v", err)
}
routeChan := make(chan netlink.RouteUpdate) routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil { if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err) return fmt.Errorf("subscribe to route updates: %v", err)
@@ -31,7 +36,26 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
// handle interface state changes
case update := <-linkChan:
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
continue
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
callback()
return nil
case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
callback()
return nil
}
}
// handle route changes // handle route changes
case route := <-routeChan: case route := <-routeChan:
@@ -43,12 +67,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
// triggered on added/replaced routes // triggered on added/replaced routes
case syscall.RTM_NEWROUTE: case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback() callback()
return nil return nil
case syscall.RTM_DELROUTE: case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) { if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback() callback()
return nil return nil
} }
} }

View File

@@ -4,9 +4,8 @@ package networkmonitor
import "context" import "context"
func (nw *NetworkMonitor) Start(context.Context, func()) error { func (nw *NetworkWatcher) Start(context.Context, func()) {
return nil
} }
func (nw *NetworkMonitor) Stop() { func (nw *NetworkWatcher) Stop() {
} }

View File

@@ -5,12 +5,11 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager"
) )
const ( const (
@@ -26,16 +25,20 @@ const (
const interval = 10 * time.Second const interval = 10 * time.Second
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
var neighborv4, neighborv6 *systemops.Neighbor var neighborv4, neighborv6 *routemanager.Neighbor
{ {
initialNeighbors, err := getNeighbors() initialNeighbors, err := getNeighbors()
if err != nil { if err != nil {
return fmt.Errorf("get neighbors: %w", err) return fmt.Errorf("get neighbors: %w", err)
} }
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors) if n, ok := initialNeighbors[nexthopv4]; ok {
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors) neighborv4 = &n
}
if n, ok := initialNeighbors[nexthopv6]; ok {
neighborv6 = &n
}
} }
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6) log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
@@ -45,31 +48,23 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
case <-ticker.C: case <-ticker.C:
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) { if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
go callback() callback()
return nil return nil
} }
} }
} }
} }
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
if n, ok := initialNeighbors[nexthop.IP]; ok &&
n.State != unreachable &&
n.State != incomplete &&
n.State != tbd {
return &n
}
return nil
}
func changed( func changed(
nexthopv4 systemops.Nexthop, nexthopv4 netip.Addr,
neighborv4 *systemops.Neighbor, intfv4 *net.Interface,
nexthopv6 systemops.Nexthop, neighborv4 *routemanager.Neighbor,
neighborv6 *systemops.Neighbor, nexthopv6 netip.Addr,
intfv6 *net.Interface,
neighborv6 *routemanager.Neighbor,
) bool { ) bool {
neighbors, err := getNeighbors() neighbors, err := getNeighbors()
if err != nil { if err != nil {
@@ -86,7 +81,7 @@ func changed(
return false return false
} }
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) { if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
return true return true
} }
@@ -94,74 +89,44 @@ func changed(
} }
// routeChanged checks if the default routes still point to our nexthop/interface // routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool { func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
if !nexthop.IP.IsValid() { if !nexthop.IsValid() {
return false return false
} }
unspec := getUnspecifiedPrefix(nexthop.IP) var unspec netip.Prefix
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) if nexthop.Is6() {
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
} else {
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n")) if r, ok := routes[unspec]; ok {
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
if !foundMatchingRoute { intf := "<nil>"
logRouteChange(nexthop.IP, intf) if r.Interface != nil {
intf = r.Interface.Name
}
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
return true
}
} else {
log.Infof("network monitor: default route is gone")
return true return true
} }
return false return false
} }
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix { func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
if ip.Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
for _, r := range routes {
if r.Destination == unspec {
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
}
}
return defaultRoutes, foundMatchingRoute
}
func formatRouteInfo(r systemops.Route) string {
newIntf := "<nil>"
if r.Interface != nil {
newIntf = r.Interface.Name
}
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
}
func logRouteChange(ip netip.Addr, intf *net.Interface) {
oldIntf := "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
}
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
if neighbor == nil { if neighbor == nil {
return false return false
} }
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces // TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop.IP]; ok { if n, ok := neighbors[nexthop]; ok {
if n.State == unreachable || n.State == incomplete { if n.State != reachable && n.State != permanent {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State)) log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true return true
} else if n.InterfaceIndex != neighbor.InterfaceIndex { } else if n.InterfaceIndex != neighbor.InterfaceIndex {
@@ -185,13 +150,13 @@ func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, ne
return false return false
} }
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) { func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
entries, err := systemops.GetNeighbors() entries, err := routemanager.GetNeighbors()
if err != nil { if err != nil {
return nil, fmt.Errorf("get neighbors: %w", err) return nil, fmt.Errorf("get neighbors: %w", err)
} }
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries)) neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
for _, entry := range entries { for _, entry := range entries {
neighbours[entry.IPAddress] = entry neighbours[entry.IPAddress] = entry
} }
@@ -199,13 +164,18 @@ func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
return neighbours, nil return neighbours, nil
} }
func getRoutes() ([]systemops.Route, error) { func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
entries, err := systemops.GetRoutes() entries, err := routemanager.GetRoutes()
if err != nil { if err != nil {
return nil, fmt.Errorf("get routes: %w", err) return nil, fmt.Errorf("get routes: %w", err)
} }
return entries, nil routes := make(map[netip.Prefix]routemanager.Route, len(entries))
for _, entry := range entries {
routes[entry.Destination] = entry
}
return routes, nil
} }
func stateFromInt(state uint8) string { func stateFromInt(state uint8) string {

View File

@@ -18,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@@ -68,6 +68,9 @@ type ConnConfig struct {
NATExternalIPs []string NATExternalIPs []string
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
// RosenpassPubKey is this peer's Rosenpass public key // RosenpassPubKey is this peer's Rosenpass public key
RosenpassPubKey []byte RosenpassPubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port) // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
@@ -98,6 +101,9 @@ type IceCredentials struct {
Pwd string Pwd string
} }
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
type Conn struct { type Conn struct {
config ConnConfig config ConnConfig
mu sync.Mutex mu sync.Mutex
@@ -127,13 +133,30 @@ type Conn struct {
wgProxyFactory *wgproxy.Factory wgProxyFactory *wgproxy.Factory
wgProxy wgproxy.Proxy wgProxy wgproxy.Proxy
remoteModeCh chan ModeMessage
meta meta
adapter iface.TunAdapter adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
sentExtraSrflx bool sentExtraSrflx bool
remoteEndpoint *net.UDPAddr
remoteConn *ice.Conn
connID nbnet.ConnectionID connID nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc afterRemovePeerHooks []AfterRemovePeerHookFunc
}
// meta holds meta information about a connection
type meta struct {
protoSupport signal.FeaturesSupport
}
// ModeMessage represents a connection mode chosen by the peer
type ModeMessage struct {
// Direct indicates that it decided to use a direct connection
Direct bool
} }
// GetConf returns the connection config // GetConf returns the connection config
@@ -162,6 +185,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
remoteOffersCh: make(chan OfferAnswer), remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
remoteModeCh: make(chan ModeMessage, 1),
wgProxyFactory: wgProxyFactory, wgProxyFactory: wgProxyFactory,
adapter: adapter, adapter: adapter,
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
@@ -329,7 +353,7 @@ func (conn *Conn) Open(ctx context.Context) error {
err = conn.agent.GatherCandidates() err = conn.agent.GatherCandidates()
if err != nil { if err != nil {
return fmt.Errorf("gather candidates: %v", err) return err
} }
// will block until connection succeeded // will block until connection succeeded
@@ -346,12 +370,14 @@ func (conn *Conn) Open(ctx context.Context) error {
return err return err
} }
// dynamically set remote WireGuard port if other side specified a different one from the default one // dynamically set remote WireGuard port is other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 { if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort remoteWgPort = remoteOfferAnswer.WgListenPort
} }
conn.remoteConn = remoteConn
// the ice connection has been established successfully so we are ready to start the proxy // the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey, remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr) remoteOfferAnswer.RosenpassAddr)
@@ -376,11 +402,11 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay return candidate.Type() == ice.CandidateTypeRelay
} }
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) { func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
} }
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) { func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
} }
@@ -397,7 +423,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
var endpoint net.Addr var endpoint net.Addr
if isRelayCandidate(pair.Local) { if isRelayCandidate(pair.Local) {
log.Debugf("setup relay connection") log.Debugf("setup relay connection")
conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx) conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn) endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -409,6 +435,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
} }
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
conn.connID = nbnet.GenerateConnID() conn.connID = nbnet.GenerateConnID()
@@ -460,10 +487,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
return nil, err return nil, err
} }
if runtime.GOOS == "ios" {
runtime.GC()
}
if conn.onConnected != nil { if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr) conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr)
} }
@@ -594,11 +617,7 @@ func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) err
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates // onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
// and then signals them to the remote peer // and then signals them to the remote peer
func (conn *Conn) onICECandidate(candidate ice.Candidate) { func (conn *Conn) onICECandidate(candidate ice.Candidate) {
// nil means candidate gathering has been ended if candidate != nil {
if candidate == nil {
return
}
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored // TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String()) log.Debugf("discovered local candidate %s", candidate.String())
go func() { go func() {
@@ -606,28 +625,33 @@ func (conn *Conn) onICECandidate(candidate ice.Candidate) {
if err != nil { if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err) log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
} }
}()
if !conn.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer // this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate) if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
relatedAdd := candidate.RelatedAddress()
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
if err != nil { if err != nil {
log.Errorf("failed creating extra server reflexive candidate %s", err) log.Errorf("failed creating extra server reflexive candidate %s", err)
return return
} }
conn.sentExtraSrflx = true
go func() {
err = conn.signalCandidate(extraSrflx) err = conn.signalCandidate(extraSrflx)
if err != nil { if err != nil {
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err) log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
return
}
conn.sentExtraSrflx = true
} }
}() }()
} }
}
func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
@@ -751,7 +775,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String()) log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
go func() { go func() {
conn.mu.Lock() conn.mu.Lock()
@@ -773,21 +797,8 @@ func (conn *Conn) GetKey() string {
return conn.config.Key return conn.config.Key
} }
func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { // RegisterProtoSupportMeta register supported proto message in the connection metadata
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
return true protoSupport := signal.ParseFeaturesSupported(support)
} conn.meta.protoSupport = protoSupport
return false
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
} }

View File

@@ -1,7 +1,6 @@
package peer package peer
import ( import (
"context"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -36,7 +35,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -51,7 +50,7 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -88,7 +87,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -124,7 +123,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -154,7 +153,7 @@ func TestConn_Status(t *testing.T) {
} }
func TestConn_Close(t *testing.T) { func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()

View File

@@ -2,17 +2,14 @@ package peer
import ( import (
"errors" "errors"
"net/netip"
"sync" "sync"
"time" "time"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
) )
// State contains the latest state of a peer // State contains the latest state of a peer
@@ -40,25 +37,25 @@ type State struct {
// AddRoute add a single route to routes map // AddRoute add a single route to routes map
func (s *State) AddRoute(network string) { func (s *State) AddRoute(network string) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
if s.routes == nil { if s.routes == nil {
s.routes = make(map[string]struct{}) s.routes = make(map[string]struct{})
} }
s.routes[network] = struct{}{} s.routes[network] = struct{}{}
s.Mux.Unlock()
} }
// SetRoutes set state routes // SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) { func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
s.routes = routes s.routes = routes
s.Mux.Unlock()
} }
// DeleteRoute removes a route from the network amp // DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) { func (s *State) DeleteRoute(network string) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
delete(s.routes, network) delete(s.routes, network)
s.Mux.Unlock()
} }
// GetRoutes return routes map // GetRoutes return routes map
@@ -136,7 +133,6 @@ type Status struct {
rosenpassEnabled bool rosenpassEnabled bool
rosenpassPermissive bool rosenpassPermissive bool
nsGroupStates []NSGroupState nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
// To reduce the number of notification invocation this bool will be true when need to call the notification // To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -152,7 +148,6 @@ func NewRecorder(mgmAddress string) *Status {
offlinePeers: make([]State, 0), offlinePeers: make([]State, 0),
notifier: newNotifier(), notifier: newNotifier(),
mgmAddress: mgmAddress, mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
} }
} }
@@ -193,7 +188,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey] state, ok := d.peers[peerPubKey]
if !ok { if !ok {
return State{}, iface.ErrPeerNotFound return State{}, errors.New("peer not found")
} }
return state, nil return state, nil
} }
@@ -434,18 +429,6 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates d.nsGroupStates = dnsStates
} }
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
}
func (d *Status) GetRosenpassState() RosenpassState { func (d *Status) GetRosenpassState() RosenpassState {
return RosenpassState{ return RosenpassState{
d.rosenpassEnabled, d.rosenpassEnabled,
@@ -510,12 +493,6 @@ func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates return d.nsGroupStates
} }
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
// GetFullStatus gets full status // GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus { func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock() d.mux.Lock()

View File

@@ -170,7 +170,7 @@ func ProbeAll(
var wg sync.WaitGroup var wg sync.WaitGroup
for i, uri := range relays { for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second) ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel() defer cancel()
wg.Add(1) wg.Add(1)

View File

@@ -3,20 +3,19 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const minRangeBits = 7
type routerPeerStatus struct { type routerPeerStatus struct {
connected bool connected bool
relayed bool relayed bool
@@ -29,42 +28,33 @@ type routesUpdate struct {
routes []*route.Route routes []*route.Route
} }
// RouteHandler defines the interface for handling routes
type RouteHandler interface {
String() string
AddRoute(ctx context.Context) error
RemoveRoute() error
AddAllowedIPs(peerKey string) error
RemoveAllowedIPs() error
}
type clientNetwork struct { type clientNetwork struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc stop context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan routesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{} routePeersNotifiers map[string]chan struct{}
currentChosen *route.Route chosenRoute *route.Route
handler RouteHandler network netip.Prefix
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &clientNetwork{
ctx: ctx, ctx: ctx,
cancel: cancel, stop: cancel,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder), network: network,
} }
return client return client
} }
@@ -96,8 +86,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Metric: Routes with lower metrics (better) are prioritized. // * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred. // * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored. // * Direct connections: Routes with direct peer connections are favored.
// * Latency: Routes with lower latency are prioritized.
// * Stability: In case of equal scores, the currently active route (if any) is maintained. // * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized.
// //
// It returns the ID of the selected optimal route. // It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
@@ -106,8 +96,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
currScore := float64(0) currScore := float64(0)
currID := route.ID("") currID := route.ID("")
if c.currentChosen != nil { if c.chosenRoute != nil {
currID = c.currentChosen.ID currID = c.chosenRoute.ID
} }
for _, r := range c.routes { for _, r := range c.routes {
@@ -161,18 +151,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
peers = append(peers, r.Peer) peers = append(peers, r.Peer)
} }
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers) log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
case chosen != currID: case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes // we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore { if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore) log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
return currID return currID
} }
var p string var p string
if rt := c.routes[chosen]; rt != nil { if rt := c.routes[chosen]; rt != nil {
p = rt.Peer p = rt.Peer
} }
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler) log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network)
} }
return chosen return chosen
@@ -206,103 +196,98 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
} }
} }
func (c *clientNetwork) removeRouteFromWireguardPeer() error { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
c.removeStateRoute() state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Errorf("get peer state: %v", err)
}
if err := c.handler.RemoveAllowedIPs(); err != nil { state.DeleteRoute(c.network.String())
return fmt.Errorf("remove allowed IPs: %w", err) if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if state.ConnStatus != peer.StatusConnected {
return nil
}
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil {
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
} }
return nil return nil
} }
func (c *clientNetwork) removeRouteFromPeerAndSystem() error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.currentChosen == nil { if c.chosenRoute != nil {
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
}
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("remove route: %v", err)
}
}
return nil return nil
} }
var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
routerPeerStatuses := c.getRouterPeerStatuses() routerPeerStatuses := c.getRouterPeerStatuses()
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses) chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system // If no route is chosen, remove the route from the peer and system
if newChosenID == "" { if chosen == "" {
if err := c.removeRouteFromPeerAndSystem(); err != nil { if err := c.removeRouteFromPeerAndSystem(); err != nil {
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err) return fmt.Errorf("remove route from peer and system: %v", err)
} }
c.currentChosen = nil c.chosenRoute = nil
return nil return nil
} }
// If the chosen route is the same as the current route, do nothing // If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID && if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
c.currentChosen.IsEqual(c.routes[newChosenID]) { if c.chosenRoute.IsEqual(c.routes[chosen]) {
return nil return nil
} }
}
if c.currentChosen == nil { if c.chosenRoute != nil {
// If they were not previously assigned to another peer, add routes to the system first // If a previous route exists, remove it from the peer
if err := c.handler.AddRoute(c.ctx); err != nil { if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("add route: %w", err) return fmt.Errorf("remove route from peer: %v", err)
} }
} else { } else {
// Otherwise, remove the allowed IPs from the previous peer first // otherwise add the route to the system
if err := c.removeRouteFromWireguardPeer(); err != nil { if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
} }
} }
c.currentChosen = c.routes[newChosenID] c.chosenRoute = c.routes[chosen]
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil { state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
} }
c.addStateRoute() if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
return nil return nil
} }
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() { go func() {
c.routeUpdate <- update c.routeUpdate <- update
@@ -333,23 +318,24 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
for { for {
select { select {
case <-c.ctx.Done(): case <-c.ctx.Done():
log.Debugf("Stopping watcher for network [%v]", c.handler) log.Debugf("stopping watcher for network %s", c.network)
if err := c.removeRouteFromPeerAndSystem(); err != nil { err := c.removeRouteFromPeerAndSystem()
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err) if err != nil {
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
} }
return return
case <-c.peerStateUpdate: case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem() err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil { if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
} }
case update := <-c.routeUpdate: case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial { if update.updateSerial < c.updateSerial {
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial) log.Warnf("Received a routes update with smaller serial number, ignoring it")
continue continue
} }
log.Debugf("Received a new client network route update for [%v]", c.handler) log.Debugf("Received a new client network route update for %s", c.network)
c.handleUpdate(update) c.handleUpdate(update)
@@ -357,7 +343,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
err := c.recalculateRouteAndUpdatePeerAndSystem() err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil { if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
} }
c.startPeersStatusChangeWatcher() c.startPeersStatusChangeWatcher()
@@ -365,9 +351,14 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler { func (c *clientNetwork) getAsInterface() *net.Interface {
if rt.IsDynamic() { intf, err := net.InterfaceByName(c.wgInterface.Name())
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder) if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
} }
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) }
return intf
} }

View File

@@ -5,7 +5,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -341,9 +340,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
// create new clientNetwork // create new clientNetwork
client := &clientNetwork{ client := &clientNetwork{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil), network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes, routes: tc.existingRoutes,
currentChosen: currentRoute, chosenRoute: currentRoute,
} }
chosenRoute := client.getBestRouteFromStatuses(tc.statuses) chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@@ -1,378 +0,0 @@
package dynamic
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const (
DefaultInterval = time.Minute
minInterval = 2 * time.Second
failureInterval = 5 * time.Second
addAllowedIP = "add allowed IP %s: %w"
)
type domainMap map[domain.Domain][]netip.Prefix
type resolveResult struct {
domain domain.Domain
prefix netip.Prefix
err error
}
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
interval time.Duration
dynamicDomains domainMap
mu sync.Mutex
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
}
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
}
}
func (r *Route) String() string {
s, err := r.route.Domains.String()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
}
func (r *Route) AddRoute(ctx context.Context) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
ctx, r.cancel = context.WithCancel(ctx)
go r.startResolver(ctx)
return nil
}
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
func (r *Route) RemoveRoute() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
var merr *multierror.Error
for domain, prefixes := range r.dynamicDomains {
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
}
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
r.statusRecorder.DeleteResolvedDomainsStates(domain)
}
r.dynamicDomains = domainMap{}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) AddAllowedIPs(peerKey string) error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for domain, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
}
r.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) RemoveAllowedIPs() error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for _, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
}
r.currentPeerKey = ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
interval := r.interval
if interval < minInterval {
interval = minInterval
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
if interval > failureInterval {
ticker.Reset(failureInterval)
}
}
for {
select {
case <-ctx.Done():
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return
case <-ticker.C:
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
// Use a lower ticker interval if the update fails
if interval > failureInterval {
ticker.Reset(failureInterval)
}
} else if interval > failureInterval {
// Reset to the original interval if the update succeeds
ticker.Reset(interval)
}
}
}
}
func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil {
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err)
}
return nil
}
func (r *Route) resolveDomains() (domainMap, error) {
results := make(chan resolveResult)
go r.resolve(results)
resolved := domainMap{}
var merr *multierror.Error
for result := range results {
if result.err != nil {
merr = multierror.Append(merr, result.err)
} else {
resolved[result.domain] = append(resolved[result.domain], result.prefix)
}
}
return resolved, nberrors.FormatErrorOrNil(merr)
}
func (r *Route) resolve(results chan resolveResult) {
var wg sync.WaitGroup
for _, d := range r.route.Domains {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
return
}
results <- resolveResult{domain: domain, prefix: prefix}
}
}(d)
}
wg.Wait()
close(results)
}
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
r.mu.Lock()
defer r.mu.Unlock()
if ctx.Err() != nil {
return ctx.Err()
}
var merr *multierror.Error
for domain, newPrefixes := range newDomains {
oldPrefixes := r.dynamicDomains[domain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
addedPrefixes, err := r.addRoutes(domain, toAdd)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(addedPrefixes) > 0 {
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
}
removedPrefixes, err := r.removeRoutes(toRemove)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(removedPrefixes) > 0 {
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
}
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
var addedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}
if r.currentPeerKey != "" {
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
addedPrefixes = append(addedPrefixes, prefix)
}
return addedPrefixes, merr.ErrorOrNil()
}
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
if r.route.KeepRoute {
return nil, nil
}
var removedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
if r.currentPeerKey != "" {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
removedPrefixes = append(removedPrefixes, prefix)
}
return removedPrefixes, merr.ErrorOrNil()
}
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
return fmt.Errorf(addAllowedIP, prefix, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
}
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
prefixSet := make(map[netip.Prefix]struct{})
for _, prefix := range oldPrefixes {
prefixSet[prefix] = struct{}{}
}
for _, prefix := range removedPrefixes {
delete(prefixSet, prefix)
}
for _, prefix := range addedPrefixes {
prefixSet[prefix] = struct{}{}
}
var combinedPrefixes []netip.Prefix
for prefix := range prefixSet {
combinedPrefixes = append(combinedPrefixes, prefix)
}
return combinedPrefixes
}

View File

@@ -2,23 +2,18 @@ package routemanager
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"net/url" "net/url"
"runtime" "runtime"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -26,9 +21,14 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
// nolint:unused
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
@@ -46,65 +46,25 @@ type DefaultManager struct {
clientNetworks map[route.HAUniqueID]*clientNetwork clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector routeSelector *routeselector.RouteSelector
serverRouter serverRouter serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
pubKey string pubKey string
notifier *notifier notifier *notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
} }
func NewManager( func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
wgInterface *iface.WGIface,
statusRecorder *peer.Status,
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(), routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
notifier: newNotifier(), notifier: newNotifier(),
} }
dm.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ any) (any, error) {
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ any) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)
dm.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
return err
}
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
}
return nil
},
)
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes) cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr) dm.notifier.setInitialClientRoutes(cr)
@@ -113,12 +73,12 @@ func NewManager(
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() { if nbnet.CustomRoutingDisabled() {
return nil, nil, nil return nil, nil, nil
} }
if err := m.sysOps.CleanupRouting(); err != nil { if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err) log.Warnf("Failed cleaning up routing: %v", err)
} }
@@ -126,7 +86,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
signalAddress := m.statusRecorder.GetSignalState().URL signalAddress := m.statusRecorder.GetSignalState().URL
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err) return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
@@ -150,19 +110,8 @@ func (m *DefaultManager) Stop() {
m.serverRouter.cleanUp() m.serverRouter.cleanUp()
} }
if m.routeRefCounter != nil {
if err := m.routeRefCounter.Flush(); err != nil {
log.Errorf("Error flushing route ref counter: %v", err)
}
}
if m.allowedIPsRefCounter != nil {
if err := m.allowedIPsRefCounter.Flush(); err != nil {
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
}
}
if !nbnet.CustomRoutingDisabled() { if !nbnet.CustomRoutingDisabled() {
if err := m.sysOps.CleanupRouting(); err != nil { if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err) log.Errorf("Error cleaning up routing: %v", err)
} else { } else {
log.Info("Routing cleanup complete") log.Info("Routing cleanup complete")
@@ -236,7 +185,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue continue
} }
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@@ -248,7 +197,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok { if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id) log.Debugf("Stopping client network watcher, %s", id)
client.cancel() client.stop()
delete(m.clientNetworks, id) delete(m.clientNetworks, id)
} }
} }
@@ -261,7 +210,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }
@@ -279,7 +228,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
ownNetworkIDs := make(map[route.HAUniqueID]bool) ownNetworkIDs := make(map[route.HAUniqueID]bool)
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
haID := newRoute.GetHAUniqueID() haID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey { if newRoute.Peer == m.pubKey {
ownNetworkIDs[haID] = true ownNetworkIDs[haID] = true
// only linux is supported for now // only linux is supported for now
@@ -292,9 +241,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
haID := newRoute.GetHAUniqueID() haID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[haID] { if !ownNetworkIDs[haID] {
if !isRouteSupported(newRoute) { if !isPrefixSupported(newRoute.Network) {
continue continue
} }
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute) newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
@@ -306,23 +255,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes) _, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0, len(crMap)) rs := make([]*route.Route, 0)
for _, routes := range crMap { for _, routes := range crMap {
rs = append(rs, routes...) rs = append(rs, routes...)
} }
return rs return rs
} }
func isRouteSupported(route *route.Route) bool { func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() { if !nbnet.CustomRoutingDisabled() {
return true return true
} }
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported // If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management // we skip this prefix management
if route.Network.Bits() <= vars.MinRangeBits { if prefix.Bits() <= minRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), route.Network) version.NetbirdVersion(), prefix)
return false return false
} }
return true return true

View File

@@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil) routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
_, _, err = routeManager.Init() _, _, err = routeManager.Init()
@@ -436,7 +436,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected expectedWatchers := testCase.clientNetworkWatchersExpected
if testCase.clientNetworkWatchersExpectedAllowed != 0 { if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
} }
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")

View File

@@ -6,10 +6,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
@@ -20,7 +20,7 @@ type MockManager struct {
StopFunc func() StopFunc func()
} }
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil return nil, nil, nil
} }

View File

@@ -1,155 +0,0 @@
package refcounter
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
var ErrIgnore = errors.New("ignore")
type Ref[O any] struct {
Count int
Out O
}
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
type Counter[I, O any] struct {
// refCountMap keeps track of the reference Ref for prefixes
refCountMap map[netip.Prefix]Ref[O]
refCountMu sync.Mutex
// idMap keeps track of the prefixes associated with an ID for removal
idMap map[string][]netip.Prefix
idMu sync.Mutex
add AddFunc[I, O]
remove RemoveFunc[I, O]
}
// New creates a new Counter instance
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
return &Counter[I, O]{
refCountMap: map[netip.Prefix]Ref[O]{},
idMap: map[string][]netip.Prefix{},
add: add,
remove: remove,
}
}
// Increment increments the reference count for the given prefix.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref := rm.refCountMap[prefix]
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
// Call AddFunc only if it's a new prefix
if ref.Count == 0 {
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
out, err := rm.add(prefix, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
}
ref.Out = out
}
ref.Count++
rm.refCountMap[prefix] = ref
return ref, nil
}
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
ref, err := rm.Increment(prefix, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
rm.idMap[id] = append(rm.idMap[id], prefix)
return ref, nil
}
// Decrement decrements the reference count for the given prefix.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref, ok := rm.refCountMap[prefix]
if !ok {
log.Tracef("No reference found for prefix %s", prefix)
return ref, nil
}
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
if ref.Count == 1 {
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
if err := rm.remove(prefix, ref.Out); err != nil {
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
}
delete(rm.refCountMap, prefix)
} else {
ref.Count--
rm.refCountMap[prefix] = ref
}
return ref, nil
}
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for _, prefix := range rm.idMap[id] {
if _, err := rm.Decrement(prefix); err != nil {
merr = multierror.Append(merr, err)
}
}
delete(rm.idMap, id)
return nberrors.FormatErrorOrNil(merr)
}
// Flush removes all references and calls RemoveFunc for each prefix.
func (rm *Counter[I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for prefix := range rm.refCountMap {
log.Tracef("Removing for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.remove(prefix, ref.Out); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]Ref[O]{}
rm.idMap = map[string][]netip.Prefix{}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,7 +0,0 @@
package refcounter
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
type RouteRefCounter = Counter[any, any]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
type AllowedIPsRefCounter = Counter[string, string]

View File

@@ -0,0 +1,127 @@
//go:build !android && !ios
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
type ref struct {
count int
nexthop netip.Addr
intf *net.Interface
}
type RouteManager struct {
// refCountMap keeps track of the reference ref for prefixes
refCountMap map[netip.Prefix]ref
// prefixMap keeps track of the prefixes associated with a connection ID for removal
prefixMap map[nbnet.ConnectionID][]netip.Prefix
addRoute AddRouteFunc
removeRoute RemoveRouteFunc
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap
return &RouteManager{
refCountMap: map[netip.Prefix]ref{},
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
addRoute: addRoute,
removeRoute: removeRoute,
}
}
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
ref := rm.refCountMap[prefix]
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
// Add route to the system, only if it's a new prefix
if ref.count == 0 {
log.Debugf("Adding route for prefix %s", prefix)
nexthop, intf, err := rm.addRoute(prefix)
if errors.Is(err, ErrRouteNotFound) {
return nil
}
if errors.Is(err, ErrRouteNotAllowed) {
log.Debugf("Adding route for prefix %s: %s", prefix, err)
}
if err != nil {
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
}
ref.nexthop = nexthop
ref.intf = intf
}
ref.count++
rm.refCountMap[prefix] = ref
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
return nil
}
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
prefixes, ok := rm.prefixMap[connID]
if !ok {
log.Debugf("No prefixes found for connection ID %s", connID)
return nil
}
var result *multierror.Error
for _, prefix := range prefixes {
ref := rm.refCountMap[prefix]
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
if ref.count == 1 {
log.Debugf("Removing route for prefix %s", prefix)
// TODO: don't fail if the route is not found
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
continue
}
delete(rm.refCountMap, prefix)
} else {
ref.count--
rm.refCountMap[prefix] = ref
}
}
delete(rm.prefixMap, connID)
return result.ErrorOrNil()
}
// Flush removes all references and routes from the system
func (rm *RouteManager) Flush() error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
var result *multierror.Error
for prefix := range rm.refCountMap {
log.Debugf("Removing route for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]ref{}
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
return result.ErrorOrNil()
}

View File

@@ -12,7 +12,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -71,7 +70,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
} }
if len(m.routes) > 0 { if len(m.routes) > 0 {
err := systemops.EnableIPForwarding() err := enableIPForwarding()
if err != nil { if err != nil {
return err return err
} }
@@ -89,7 +88,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route) routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
if err != nil { if err != nil {
return fmt.Errorf("parse prefix: %w", err) return fmt.Errorf("parse prefix: %w", err)
} }
@@ -118,7 +117,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route) routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
if err != nil { if err != nil {
return fmt.Errorf("parse prefix: %w", err) return fmt.Errorf("parse prefix: %w", err)
} }
@@ -134,13 +133,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
if state.Routes == nil { if state.Routes == nil {
state.Routes = map[string]struct{}{} state.Routes = map[string]struct{}{}
} }
state.Routes[route.Network.String()] = struct{}{}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state) m.statusRecorder.UpdateLocalPeerState(state)
return nil return nil
@@ -151,7 +144,7 @@ func (m *defaultServerRouter) cleanUp() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
for _, r := range m.routes { for _, r := range m.routes {
routerPair, err := routeToRouterPair(r) routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
if err != nil { if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err) log.Errorf("Failed to convert route to router pair: %v", err)
continue continue
@@ -169,27 +162,15 @@ func (m *defaultServerRouter) cleanUp() {
m.statusRecorder.UpdateLocalPeerState(state) m.statusRecorder.UpdateLocalPeerState(state)
} }
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6 parsed, err := netip.ParsePrefix(source)
source := getDefaultPrefix(route.Network) if err != nil {
return firewall.RouterPair{}, err
destination := route.Network.Masked().String()
if route.IsDynamic() {
// TODO: add ipv6
destination = "0.0.0.0/0"
} }
return firewall.RouterPair{ return firewall.RouterPair{
ID: string(route.ID), ID: string(route.ID),
Source: source.String(), Source: parsed.String(),
Destination: destination, Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade, Masquerade: route.Masquerade,
}, nil }, nil
} }
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@@ -1,57 +0,0 @@
package static
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
}
}
// Route route methods
func (r *Route) String() string {
return r.route.Network.String()
}
func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
return err
}
func (r *Route) RemoveRoute() error {
_, err := r.routeRefCounter.Decrement(r.route.Network)
return err
}
func (r *Route) AddAllowedIPs(peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
r.route.Network,
ref.Out,
)
}
return nil
}
func (r *Route) RemoveAllowedIPs() error {
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
return err
}

View File

@@ -1,103 +0,0 @@
// go:build !android
package sysctl
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/iface"
)
const (
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := Set(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = Set(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := Set(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, nberrors.FormatErrorOrNil(result)
}
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
// Cleanup resets sysctl settings to their original values.
func Cleanup(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := Set(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -0,0 +1,414 @@
//go:build !android && !ios
package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRouteNotFound = errors.New("route not found")
var ErrRouteNotAllowed = errors.New("route not allowed")
// TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
defaultGateway, _, err := GetNextHop(addr)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(defaultGateway) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
if defaultGateway.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
gatewayHop, intf, err := GetNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
}
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
r, err := netroute.New()
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
}
return addr.Unmap(), intf, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
}
return addr, intf, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, nil, ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := GetNextHop(addr)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
exitIntf := intf
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
exitIntf = initialIntf
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}

View File

@@ -1,18 +0,0 @@
//go:build darwin || dragonfly || netbsd || openbsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
return false
}

View File

@@ -1,19 +0,0 @@
//go:build: freebsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
return false
}

View File

@@ -1,27 +0,0 @@
package systemops
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
IP netip.Addr
Intf *net.Interface
}
type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
}
}

View File

@@ -1,473 +0,0 @@
//go:build !android && !ios
package systemops
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
refCounter := refcounter.New(
func(prefix netip.Prefix, _ any) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
}
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
},
r.removeFromRouteTable,
)
r.refCounter = refCounter
return r.setupHooks(initAddresses)
}
func (r *SysOps) cleanupRefCounter() error {
if r.refCounter == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
exitNextHop := Nexthop{
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == vars.Defaultv6 {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
} else if prefix == vars.Defaultv6 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(result)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}
func GetNextHop(ip netip.Addr) (Nexthop, error) {
r, err := netroute.New()
if err != nil {
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
}
return Nexthop{
IP: addr,
Intf: intf,
}, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
}
return Nexthop{
IP: addr,
Intf: intf,
}, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
zone := intf.Name
if runtime.GOOS == "windows" {
zone = strconv.Itoa(intf.Index)
}
log.Tracef("Adding zone %s to address %s", zone, addr)
addr = addr.WithZone(zone)
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()
if err != nil {
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{}
}
return isVpnRoute(addr, vpnRoutes, localRoutes)
}
func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) {
vpnPrefixMap := map[netip.Prefix]struct{}{}
for _, prefix := range vpnRoutes {
vpnPrefixMap[prefix] = struct{}{}
}
// remove vpnRoute duplicates
for _, prefix := range localRoutes {
delete(vpnPrefixMap, prefix)
}
var longestPrefix netip.Prefix
var isVpn bool
combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes))
copy(combinedRoutes, vpnRoutes)
copy(combinedRoutes[len(vpnRoutes):], localRoutes)
for _, prefix := range combinedRoutes {
// Ignore the default route, it has special handling
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
// Longest prefix match
if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() {
longestPrefix = prefix
_, isVpn = vpnPrefixMap[prefix]
}
}
}
if !longestPrefix.IsValid() {
// No route matched
return false, netip.Prefix{}
}
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}

View File

@@ -1,38 +0,0 @@
//go:build ios || android
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
return nil
}
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}

View File

@@ -1,28 +0,0 @@
//go:build !linux && !ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericRemoveVPNRoute(prefix, intf)
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return getRoutesFromTable()
}

View File

@@ -0,0 +1,33 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@@ -1,6 +1,6 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd //go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops package routemanager
import ( import (
"errors" "errors"
@@ -43,7 +43,8 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type) return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
} }
if filterRoutesByFlags(m.Flags) { if m.Flags&syscall.RTF_UP == 0 ||
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
continue continue
} }
@@ -92,7 +93,7 @@ func toNetIP(a route.Addr) netip.Addr {
case *route.Inet6Addr: case *route.Inet6Addr:
ip := netip.AddrFrom16(t.IP) ip := netip.AddrFrom16(t.IP)
if t.ZoneID != 0 { if t.ZoneID != 0 {
ip = ip.WithZone(strconv.Itoa(t.ZoneID)) ip.WithZone(strconv.Itoa(t.ZoneID))
} }
return ip return ip
default: default:
@@ -100,7 +101,6 @@ func toNetIP(a route.Addr) netip.Addr {
} }
} }
// ones returns the number of leading ones in the mask.
func ones(a route.Addr) (int, error) { func ones(a route.Addr) (int, error) {
switch t := a.(type) { switch t := a.(type) {
case *route.Inet4Addr: case *route.Inet4Addr:
@@ -114,7 +114,6 @@ func ones(a route.Addr) (int, error) {
} }
} }
// MsgToRoute converts a route message to a Route.
func MsgToRoute(msg *route.RouteMessage) (*Route, error) { func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2] dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]

View File

@@ -0,0 +1,57 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package routemanager
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/route"
)
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

View File

@@ -1,6 +1,6 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd //go:build darwin && !ios
package systemops package routemanager
import ( import (
"fmt" "fmt"
@@ -13,41 +13,48 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { var routeManager *RouteManager
return r.setupRefCounter(initAddresses)
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
func (r *SysOps) CleanupRouting() error { func cleanupRouting() error {
return r.cleanupRefCounter() return cleanupRoutingWithRouteManager(routeManager)
} }
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return r.routeCmd("add", prefix, nexthop) return routeCmd("add", prefix, nexthop, intf)
} }
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return r.routeCmd("delete", prefix, nexthop) return routeCmd("delete", prefix, nexthop, intf)
} }
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error { func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
inet := "-inet" inet := "-inet"
if prefix.Addr().Is6() {
inet = "-inet6"
}
network := prefix.String() network := prefix.String()
if prefix.IsSingleIP() { if prefix.IsSingleIP() {
network = prefix.Addr().String() network = prefix.Addr().String()
} }
if prefix.Addr().Is6() {
inet = "-inet6"
// Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 {
intf = &net.Interface{Name: "lo0"}
}
}
args := []string{"-n", action, inet, network} args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() { if nexthop.IsValid() {
args = append(args, nexthop.IP.Unmap().String()) args = append(args, nexthop.Unmap().String())
} else if nexthop.Intf != nil { } else if intf != nil {
args = append(args, "-interface", nexthop.Intf.Name) args = append(args, "-interface", intf.Name)
} }
if err := retryRouteCmd(args); err != nil { if err := retryRouteCmd(args); err != nil {

View File

@@ -1,6 +1,6 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd //go:build !ios
package systemops package routemanager
import ( import (
"fmt" "fmt"
@@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/route"
) )
var expectedVPNint = "utun100" var expectedVPNint = "utun100"
@@ -36,15 +35,13 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0") baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"} intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {
wg.Add(1) wg.Add(1)
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err) t.Errorf("Failed to add route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)
@@ -60,7 +57,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err) t.Errorf("Failed to remove route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)
@@ -70,53 +67,6 @@ func TestConcurrentRoutes(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper() t.Helper()

View File

@@ -0,0 +1,33 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@@ -1,6 +1,6 @@
//go:build !android //go:build !android
package systemops package routemanager
import ( import (
"bufio" "bufio"
@@ -9,15 +9,16 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strconv"
"strings"
"syscall" "syscall"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -32,10 +33,16 @@ const (
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting. // ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward" ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
) )
var ErrTableIDExists = errors.New("ID exists with different name") var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified // originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int var originalSysctl map[string]int
@@ -75,7 +82,7 @@ func getSetupRules() []ruleParams {
} }
} }
// SetupRouting establishes the routing configuration for the VPN, including essential rules // setupRouting establishes the routing configuration for the VPN, including essential rules
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. // to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
// //
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over // Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
@@ -85,17 +92,17 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy() { if isLegacy() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses) return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
if err = addRoutingTableName(); err != nil { if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err) log.Errorf("Error adding routing table name: %v", err)
} }
originalValues, err := sysctl.Setup(r.wgInterface) originalValues, err := setupSysctl(wgIface)
if err != nil { if err != nil {
log.Errorf("Error setting up sysctl: %v", err) log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true sysctlFailed = true
@@ -104,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
defer func() { defer func() {
if err != nil { if err != nil {
if cleanErr := r.CleanupRouting(); cleanErr != nil { if cleanErr := cleanupRouting(); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr) log.Errorf("Error cleaning up routing: %v", cleanErr)
} }
} }
@@ -116,7 +123,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
if errors.Is(err, syscall.EOPNOTSUPP) { if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true) setIsLegacy(true)
return r.setupRefCounter(initAddresses) return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
} }
@@ -125,12 +132,12 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
return nil, nil, nil return nil, nil, nil
} }
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state. // It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process. // The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting() error { func cleanupRouting() error {
if isLegacy() { if isLegacy() {
return r.cleanupRefCounter() return cleanupRoutingWithRouteManager(routeManager)
} }
var result *multierror.Error var result *multierror.Error
@@ -149,58 +156,58 @@ func (r *SysOps) CleanupRouting() error {
} }
} }
if err := sysctl.Cleanup(originalSysctl); err != nil { if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err)) result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
} }
originalSysctl = nil originalSysctl = nil
sysctlFailed = false sysctlFailed = false
return nberrors.FormatErrorOrNil(result) return result.ErrorOrNil()
} }
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN) return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
} }
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN) return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
} }
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if isLegacy() {
return r.genericAddVPNRoute(prefix, intf) return genericAddVPNRoute(prefix, intf)
} }
if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) { if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)") log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
} }
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support // TODO remove this once we have ipv6 support
if prefix == vars.Defaultv4 { if prefix == defaultv4 {
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err) return fmt.Errorf("add blackhole: %w", err)
} }
} }
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err) return fmt.Errorf("add route: %w", err)
} }
return nil return nil
} }
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if isLegacy() {
return r.genericRemoveVPNRoute(prefix, intf) return genericRemoveVPNRoute(prefix, intf)
} }
// TODO remove this once we have ipv6 support // TODO remove this once we have ipv6 support
if prefix == vars.Defaultv4 { if prefix == defaultv4 {
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err) return fmt.Errorf("remove unreachable route: %w", err)
} }
} }
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err) return fmt.Errorf("remove route: %w", err)
} }
return nil return nil
@@ -248,7 +255,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
} }
// addRoute adds a route to a specific routing table identified by tableID. // addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
route := &netlink.Route{ route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE, Scope: netlink.SCOPE_UNIVERSE,
Table: tableID, Table: tableID,
@@ -261,7 +268,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
} }
route.Dst = ipNet route.Dst = ipNet
if err := addNextHop(nexthop, route); err != nil { if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err) return fmt.Errorf("add gateway and device: %w", err)
} }
@@ -320,7 +327,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
} }
// removeRoute removes a route from a specific routing table identified by tableID. // removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String()) _, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil { if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err) return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -333,7 +340,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
Dst: ipNet, Dst: ipNet,
} }
if err := addNextHop(nexthop, route); err != nil { if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err) return fmt.Errorf("add gateway and device: %w", err)
} }
@@ -366,11 +373,11 @@ func flushRoutes(tableID, family int) error {
} }
} }
return nberrors.FormatErrorOrNil(result) return result.ErrorOrNil()
} }
func EnableIPForwarding() error { func enableIPForwarding() error {
_, err := sysctl.Set(ipv4ForwardingPath, 1, false) _, err := setSysctl(ipv4ForwardingPath, 1, false)
return err return err
} }
@@ -474,19 +481,19 @@ func removeRule(params ruleParams) error {
} }
// addNextHop adds the gateway and device to the route. // addNextHop adds the gateway and device to the route.
func addNextHop(nexthop Nexthop, route *netlink.Route) error { func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if nexthop.Intf != nil { if intf != nil {
route.LinkIndex = nexthop.Intf.Index route.LinkIndex = intf.Index
} }
if nexthop.IP.IsValid() { if addr.IsValid() {
route.Gw = nexthop.IP.AsSlice() route.Gw = addr.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index // if zone is set, it means the gateway is a link-local address, so we set the link index
if nexthop.IP.Zone() != "" && nexthop.Intf == nil { if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(nexthop.IP.Zone()) link, err := netlink.LinkByName(addr.Zone())
if err != nil { if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err) return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
} }
route.LinkIndex = link.Attrs().Index route.LinkIndex = link.Attrs().Index
} }
@@ -502,9 +509,82 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
func hasSeparateRouting() ([]netip.Prefix, error) { // setupSysctl configures sysctl settings for RP filtering and source validation.
if isLegacy() { func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
return getRoutesFromTable() keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
} }
return nil, ErrRoutingIsSeparate
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
} }

View File

@@ -1,6 +1,6 @@
//go:build !android //go:build !android
package systemops package routemanager
import ( import (
"errors" "errors"
@@ -14,8 +14,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
) )
var expectedVPNint = "wgtest0" var expectedVPNint = "wgtest0"
@@ -140,7 +138,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
if dstIPNet.String() == "0.0.0.0/0" { if dstIPNet.String() == "0.0.0.0/0" {
var err error var err error
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { if err != nil && !errors.Is(err, ErrRouteNotFound) {
t.Logf("Failed to fetch original gateway: %v", err) t.Logf("Failed to fetch original gateway: %v", err)
} }
@@ -195,7 +193,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
} }
} }
return nil, 0, vars.ErrRouteNotFound return nil, 0, ErrRouteNotFound
} }
func setupDummyInterfacesAndRoutes(t *testing.T) { func setupDummyInterfacesAndRoutes(t *testing.T) {

View File

@@ -0,0 +1,24 @@
//go:build !linux && !ios
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@@ -1,6 +1,6 @@
//go:build !android && !ios //go:build !android && !ios
package systemops package routemanager
import ( import (
"bytes" "bytes"
@@ -49,10 +49,6 @@ func TestAddRemoveRoutes(t *testing.T) {
} }
for n, testCase := range testCases { for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
@@ -61,26 +57,23 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, wgInterface)
r := NewSysOps(wgInterface)
_, _, err = r.SetupRouting(nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting()) assert.NoError(t, cleanupRouting())
}) })
index, err := net.InterfaceByName(wgInterface.Name()) index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = r.AddVPNRoute(testCase.prefix, intf) err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err") require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard { if testCase.shouldRouteToWireguard {
@@ -91,19 +84,19 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix) exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err") require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard { if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf) err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err") require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixNexthop, err := GetNextHop(testCase.prefix.Addr()) prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err") require.NoError(t, err, "GetNextHop should not return err")
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err) require.NoError(t, err)
if testCase.shouldBeRemoved { if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway") require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
} else { } else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway") require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
} }
} }
}) })
@@ -111,14 +104,11 @@ func TestAddRemoveRoutes(t *testing.T) {
} }
func TestGetNextHop(t *testing.T) { func TestGetNextHop(t *testing.T) {
if runtime.GOOS == "freebsd" { gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil { if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err) t.Fatal("shouldn't return error when fetching the gateway: ", err)
} }
if !nexthop.IP.IsValid() { if !gateway.IsValid() {
t.Fatal("should return a gateway") t.Fatal("should return a gateway")
} }
addresses, err := net.InterfaceAddrs() addresses, err := net.InterfaceAddrs()
@@ -140,24 +130,24 @@ func TestGetNextHop(t *testing.T) {
} }
} }
localIP, err := GetNextHop(testingPrefix.Addr()) localIP, _, err := GetNextHop(testingPrefix.Addr())
if err != nil { if err != nil {
t.Fatal("shouldn't return error: ", err) t.Fatal("shouldn't return error: ", err)
} }
if !localIP.IP.IsValid() { if !localIP.IsValid() {
t.Fatal("should return a gateway for local network") t.Fatal("should return a gateway for local network")
} }
if localIP.IP.String() == nexthop.IP.String() { if localIP.String() == gateway.String() {
t.Fatal("local IP should not match with gateway IP") t.Fatal("local ip should not match with gateway IP")
} }
if localIP.IP.String() != testingIP { if localIP.String() != testingIP {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String()) t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
} }
} }
func TestAddExistAndRemoveRoute(t *testing.T) { func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop) t.Log("defaultGateway: ", defaultGateway)
if err != nil { if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err) t.Fatal("shouldn't return error when fetching the gateway: ", err)
} }
@@ -174,7 +164,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}, },
{ {
name: "Should Not Add Route if overlaps with default gateway", name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"), prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
shouldAddRoute: false, shouldAddRoute: false,
}, },
{ {
@@ -213,7 +203,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@@ -224,16 +214,14 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface)
// Prepare the environment // Prepare the environment
if testCase.preExistingPrefix.IsValid() { if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf) err := addVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route") require.NoError(t, err, "should not return err when adding pre-existing route")
} }
// Add the route // Add the route
err = r.AddVPNRoute(testCase.prefix, intf) err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route") require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute { if testCase.shouldAddRoute {
@@ -243,7 +231,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist") require.True(t, ok, "route should exist")
// remove route again if added // remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf) err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
} }
@@ -307,22 +295,19 @@ func TestExistsInRouteTable(t *testing.T) {
var addressPrefixes []netip.Prefix var addressPrefixes []netip.Prefix
for _, address := range addresses { for _, address := range addresses {
p := netip.MustParsePrefix(address.String()) p := netip.MustParsePrefix(address.String())
if p.Addr().Is6() {
switch {
case p.Addr().Is6():
continue continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
} }
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
continue
}
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
continue
}
addressPrefixes = append(addressPrefixes, p.Masked())
} }
for _, prefix := range addressPrefixes { for _, prefix := range addressPrefixes {
@@ -345,7 +330,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
require.NoError(t, err) require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface") require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create() err = wgInterface.Create()
@@ -358,52 +343,65 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
return wgInterface return wgInterface
} }
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func setupTestEnv(t *testing.T) { func setupTestEnv(t *testing.T) {
t.Helper() t.Helper()
setupDummyInterfacesAndRoutes(t) setupDummyInterfacesAndRoutes(t)
wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, wgInterface.Close()) assert.NoError(t, wgIface.Close())
}) })
r := NewSysOps(wgInterface) _, _, err := setupRouting(nil, wgIface)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting()) assert.NoError(t, cleanupRouting())
}) })
index, err := net.InterfaceByName(wgInterface.Name()) index, err := net.InterfaceByName(wgIface.Name())
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
// default route exists in main table and vpn table // default route exists in main table and vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf) err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table // 10.0.0.0/8 route exists in main table and vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf) err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table // 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf) err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table // 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf) err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table // unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf) err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
} }
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
@@ -412,133 +410,11 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
return return
} }
prefixNexthop, err := GetNextHop(prefix.Addr()) prefixGateway, _, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err") require.NoError(t, err, "GetNextHop should not return err")
if invert { if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP") assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
} else { } else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP") assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
} }
} }

View File

@@ -1,11 +1,10 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly //go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops package routemanager
import ( import (
"fmt" "fmt"
"net" "net"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -86,10 +85,6 @@ var testCases = []testCase{
func TestRouting(t *testing.T) { func TestRouting(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", tc.name, " on freebsd")
}
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t) setupTestEnv(t)

View File

@@ -1,6 +1,6 @@
//go:build windows //go:build windows
package systemops package routemanager
import ( import (
"fmt" "fmt"
@@ -17,7 +17,8 @@ import (
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
) )
type MSFT_NetRoute struct { type MSFT_NetRoute struct {
@@ -56,42 +57,14 @@ var prefixList []netip.Prefix
var lastUpdate time.Time var lastUpdate time.Time
var mux = sync.Mutex{} var mux = sync.Mutex{}
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { var routeManager *RouteManager
return r.setupRefCounter(initAddresses)
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
func (r *SysOps) CleanupRouting() error { func cleanupRouting() error {
return r.cleanupRefCounter() return cleanupRoutingWithRouteManager(routeManager)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
nexthop.Intf = &net.Interface{Index: zone}
}
return addRouteCmd(prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
} }
func getRoutesFromTable() ([]netip.Prefix, error) { func getRoutesFromTable() ([]netip.Prefix, error) {
@@ -120,7 +93,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func GetRoutes() ([]Route, error) { func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute var entries []MSFT_NetRoute
query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute` query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil { if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("get routes: %w", err) return nil, fmt.Errorf("get routes: %w", err)
} }
@@ -145,10 +118,6 @@ func GetRoutes() ([]Route, error) {
Index: int(entry.InterfaceIndex), Index: int(entry.InterfaceIndex),
Name: entry.InterfaceAlias, Name: entry.InterfaceAlias,
} }
if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
}
} }
routes = append(routes, Route{ routes = append(routes, Route{
@@ -188,12 +157,11 @@ func GetNeighbors() ([]Neighbor, error) {
return neighbors, nil return neighbors, nil
} }
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error { func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
args := []string{"add", prefix.String()} args := []string{"add", prefix.String()}
if nexthop.IP.IsValid() { if nexthop.IsValid() {
ip := nexthop.IP.WithZone("") args = append(args, nexthop.Unmap().String())
args = append(args, ip.Unmap().String())
} else { } else {
addr := "0.0.0.0" addr := "0.0.0.0"
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
@@ -202,8 +170,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args = append(args, addr) args = append(args, addr)
} }
if nexthop.Intf != nil { if intf != nil {
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index)) args = append(args, "if", strconv.Itoa(intf.Index))
} }
routeCmd := uspfilter.GetSystem32Command("route") routeCmd := uspfilter.GetSystem32Command("route")
@@ -217,6 +185,37 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
return nil return nil
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("")
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func isCacheDisabled() bool { func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
} }

View File

@@ -1,4 +1,4 @@
package systemops package routemanager
import ( import (
"context" "context"
@@ -29,7 +29,7 @@ type FindNetRouteOutput struct {
InterfaceIndex int `json:"InterfaceIndex"` InterfaceIndex int `json:"InterfaceIndex"`
InterfaceAlias string `json:"InterfaceAlias"` InterfaceAlias string `json:"InterfaceAlias"`
AddressFamily int `json:"AddressFamily"` AddressFamily int `json:"AddressFamily"`
NextHop string `json:"Nexthop"` NextHop string `json:"NextHop"`
DestinationPrefix string `json:"DestinationPrefix"` DestinationPrefix string `json:"DestinationPrefix"`
} }
@@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
host, _, err := net.SplitHostPort(destination) host, _, err := net.SplitHostPort(destination)
require.NoError(t, err) require.NoError(t, err)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host) script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
out, err := exec.Command("powershell", "-Command", script).Output() out, err := exec.Command("powershell", "-Command", script).Output()
require.NoError(t, err, "Failed to execute Find-NetRoute") require.NoError(t, err, "Failed to execute Find-NetRoute")
@@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
} }
func fetchOriginalGateway() (*RouteInfo, error) { func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json") cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)

View File

@@ -1,29 +0,0 @@
package util
import (
"fmt"
"net"
"net/netip"
)
// GetPrefixFromIP returns a netip.Prefix from a net.IP address.
func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return prefix, nil
}

View File

@@ -1,16 +0,0 @@
package vars
import (
"errors"
"net/netip"
)
const MinRangeBits = 7
var (
ErrRouteNotFound = errors.New("route not found")
ErrRouteNotAllowed = errors.New("route not allowed")
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
)

View File

@@ -3,11 +3,11 @@ package routeselector
import ( import (
"fmt" "fmt"
"slices" "slices"
"strings"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors"
route "github.com/netbirdio/netbird/route" route "github.com/netbirdio/netbird/route"
) )
@@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
var err *multierror.Error var multiErr *multierror.Error
for _, route := range routes { for _, route := range routes {
if !slices.Contains(allRoutes, route) { if !slices.Contains(allRoutes, route) {
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue continue
} }
@@ -41,7 +41,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
} }
rs.selectAll = false rs.selectAll = false
return errors.FormatErrorOrNil(err) if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
} }
// SelectAllRoutes sets the selector to select all routes. // SelectAllRoutes sets the selector to select all routes.
@@ -61,17 +65,21 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
} }
} }
var err *multierror.Error var multiErr *multierror.Error
for _, route := range routes { for _, route := range routes {
if !slices.Contains(allRoutes, route) { if !slices.Contains(allRoutes, route) {
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue continue
} }
delete(rs.selectedRoutes, route) delete(rs.selectedRoutes, route)
} }
return errors.FormatErrorOrNil(err) if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
} }
// DeselectAllRoutes deselects all routes, effectively disabling route selection. // DeselectAllRoutes deselects all routes, effectively disabling route selection.
@@ -103,3 +111,18 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
} }
return filtered return filtered
} }
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
routes := route.HAMap{ routes := route.HAMap{
"route1|10.0.0.0/8": {}, "route1-10.0.0.0/8": {},
"route2|192.168.0.0/16": {}, "route2-192.168.0.0/16": {},
"route3|172.16.0.0/12": {}, "route3-172.16.0.0/12": {},
} }
filtered := rs.FilterSelected(routes) filtered := rs.FilterSelected(routes)
assert.Equal(t, route.HAMap{ assert.Equal(t, route.HAMap{
"route1|10.0.0.0/8": {}, "route1-10.0.0.0/8": {},
"route2|192.168.0.0/16": {}, "route2-192.168.0.0/16": {},
}, filtered) }, filtered)
} }

View File

@@ -1,7 +1,7 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html>
<head> <head>
<meta content="width=device-width, initial-scale=1" name="viewport"/> <meta name="viewport" content="width=device-width, initial-scale=1"/>
<style> <style>
body { body {
display: flex; display: flex;
@@ -50,17 +50,16 @@
color: black; color: black;
} }
</style> </style>
<title>NetBird Login Successful</title>
</head> </head>
<body> <body>
<div class="container"> <div class="container">
<div class="logo"> <div class="logo">
<img alt="netbird_logo" src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png"> <img src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
</div> </div>
<br> <br>
{{ if .Error }} {{ if .Error }}
<svg height="50" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg"> <svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" fill="none" r="45" stroke="red" stroke-width="3"/> <circle cx="50" cy="50" r="45" fill="none" stroke="red" stroke-width="3"/>
<path d="M30 30 L70 70 M30 70 L70 30" fill="none" stroke="red" stroke-width="3"/> <path d="M30 30 L70 70 M30 70 L70 30" fill="none" stroke="red" stroke-width="3"/>
</svg> </svg>
<div class="content"> <div class="content">
@@ -70,8 +69,8 @@
{{ .Error }}. {{ .Error }}.
</div> </div>
{{ else }} {{ else }}
<svg height="50" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg"> <svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" fill="none" r="45" stroke="#5cb85c" stroke-width="3"/> <circle cx="50" cy="50" r="45" fill="none" stroke="#5cb85c" stroke-width="3"/>
<path d="M30 50 L45 65 L70 35" fill="none" stroke="#5cb85c" stroke-width="5"/> <path d="M30 50 L45 65 L70 35" fill="none" stroke="#5cb85c" stroke-width="5"/>
</svg> </svg>
<div class="content"> <div class="content">

View File

@@ -1,17 +1,15 @@
package wgproxy package wgproxy
import "context"
type Factory struct { type Factory struct {
wgPort int wgPort int
ebpfProxy Proxy ebpfProxy Proxy
} }
func (w *Factory) GetProxy(ctx context.Context) Proxy { func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil { if w.ebpfProxy != nil {
return w.ebpfProxy return w.ebpfProxy
} }
return NewWGUserSpaceProxy(ctx, w.wgPort) return NewWGUserSpaceProxy(w.wgPort)
} }
func (w *Factory) Free() error { func (w *Factory) Free() error {

View File

@@ -3,19 +3,13 @@
package wgproxy package wgproxy
import ( import (
"context"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { func NewFactory(wgPort int) *Factory {
f := &Factory{wgPort: wgPort} f := &Factory{wgPort: wgPort}
if userspace { ebpfProxy := NewWGEBPFProxy(wgPort)
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen() err := ebpfProxy.listen()
if err != nil { if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)

View File

@@ -2,8 +2,6 @@
package wgproxy package wgproxy
import "context" func NewFactory(wgPort int) *Factory {
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort} return &Factory{wgPort: wgPort}
} }

View File

@@ -3,7 +3,6 @@
package wgproxy package wgproxy
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -23,13 +22,9 @@ import (
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int localWGListenPort int
ebpfManager ebpfMgr.Manager
lastUsedPort uint16
turnConnStore map[uint16]net.Conn turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex turnConnMutex sync.Mutex
@@ -39,7 +34,7 @@ type WGEBPFProxy struct {
} }
// NewWGEBPFProxy create new WGEBPFProxy instance // NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy { func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy") log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{ wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
@@ -47,8 +42,6 @@ func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
lastUsedPort: 0, lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn), turnConnStore: make(map[uint16]net.Conn),
} }
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy return wgProxy
} }
@@ -137,18 +130,14 @@ func (p *WGEBPFProxy) Free() error {
} }
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
buf := make([]byte, 1500)
var err error
defer func() { defer func() {
log.Tracef("stop proxying turn traffic to wg: %d", endpointPort)
p.removeTurnConn(endpointPort) p.removeTurnConn(endpointPort)
}() }()
buf := make([]byte, 1500)
for { for {
select { n, err := remoteConn.Read(buf)
case <-p.ctx.Done():
return
default:
var n int
n, err = remoteConn.Read(buf)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
@@ -157,8 +146,10 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
} }
err = p.sendPkg(buf[:n], endpointPort) err = p.sendPkg(buf[:n], endpointPort)
if err != nil { if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err) if err == io.EOF {
return
} }
log.Errorf("failed to write out turn pkg to local conn: %v", err)
} }
} }
} }
@@ -167,10 +158,6 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
select {
case <-p.ctx.Done():
return
default:
n, addr, err := p.conn.ReadFromUDP(buf) n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil { if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err) log.Errorf("failed to read UDP pkg from WG: %s", err)
@@ -191,7 +178,6 @@ func (p *WGEBPFProxy) proxyToRemote() {
} }
} }
} }
}
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) { func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
@@ -206,11 +192,9 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
} }
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) { func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Tracef("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock() defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID) delete(p.turnConnStore, turnConnID)
} }
func (p *WGEBPFProxy) nextFreePort() (uint16, error) { func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
@@ -286,17 +270,20 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
err := udpH.SetNetworkLayerForChecksum(ipH) err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil { if err != nil {
return fmt.Errorf("set network layer for checksum: %w", err) log.Errorf("set network layer for checksum: %s", err)
return err
} }
layerBuffer := gopacket.NewSerializeBuffer() layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload) err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil { if err != nil {
return fmt.Errorf("serialize layers: %w", err) log.Errorf("serialize layers: %s", err)
return err
} }
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
return fmt.Errorf("write to raw conn: %w", err) log.Errorf("write to raw conn: %s", err)
return err
} }
return nil return nil
} }

View File

@@ -3,12 +3,11 @@
package wgproxy package wgproxy
import ( import (
"context"
"testing" "testing"
) )
func TestWGEBPFProxy_connStore(t *testing.T) { func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil) p, _ := wgProxy.storeTurnConn(nil)
if p != 1 { if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535 wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ { for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)

View File

@@ -21,12 +21,12 @@ type WGUserSpaceProxy struct {
} }
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy // NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy { func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{ p := &WGUserSpaceProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
} }
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(context.Background())
return p return p
} }
@@ -35,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn p.remoteConn = turnConn
var err error var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err) log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err return nil, err

Some files were not shown because too many files have changed in this diff Show More