Compare commits

...

58 Commits

Author SHA1 Message Date
Maycon Santos
b56f61bf1b [misc] fix relay exposed address test (#3931) 2025-06-05 15:44:44 +02:00
Viktor Liu
64f111923e [client] Increase stun status probe timeout (#3930) 2025-06-05 15:22:59 +02:00
Abdul Latif
122a89c02b [misc] remove error causing dnf config-manager add (#3925) 2025-06-05 14:28:19 +02:00
Robert Neumann
c6cceba381 Update getting-started-with-zitadel.sh - fix zitadel user console (#3446) 2025-06-05 14:16:04 +02:00
Ghazy Abdallah
6c0cdb6ed1 [misc] fix: traefik relay accessibility (#3696) 2025-06-05 14:15:01 +02:00
Viktor Liu
84354951d3 [client] Add systemd netbird logs to debug bundle (#3917) 2025-06-05 13:54:15 +02:00
Viktor Liu
55957a1960 [client] Run registerdns before flushing (#3926)
* Run registerdns before flushing

* Disable WINS, dynamic updates and registration
2025-06-05 12:40:23 +02:00
Viktor Liu
df82a45d99 [client] Improve dns match trace log (#3928) 2025-06-05 12:39:58 +02:00
Zoltan Papp
9424b88db2 [client] Add output similar to wg show to the debug package (#3922) 2025-06-05 11:51:39 +02:00
Viktor Liu
609654eee7 [client] Allow userspace local forwarding to internal interfaces if requested (#3884) 2025-06-04 18:12:48 +02:00
Bethuel Mmbaga
b604c66140 [management] Add postgres support for activity event store (#3890) 2025-06-04 17:38:49 +03:00
Viktor Liu
ea4d13e96d [client] Use platform-native routing APIs for freeBSD, macOS and Windows 2025-06-04 16:28:58 +02:00
Pedro Maia Costa
87148c503f [management] support account retrieval and creation by private domain (#3825)
* [management] sys initiator save user (#3911)

* [management] activity events with multiple external account users (#3914)
2025-06-04 11:21:31 +01:00
Viktor Liu
0cd36baf67 [client] Allow the netbird service to log to console (#3916) 2025-06-03 13:09:39 +02:00
Viktor Liu
06980e7fa0 [client] Apply routes right away instead of on peer connection (#3907) 2025-06-03 10:53:39 +02:00
Viktor Liu
1ce4ee0cef [client] Add block inbound flag to disallow inbound connections of any kind (#3897) 2025-06-03 10:53:27 +02:00
Viktor Liu
f367925496 [client] Log duplicate client ui pid (#3915) 2025-06-03 10:52:10 +02:00
hakansa
616b19c064 [client] Add "Deselect All" Menu Item to Exit Node Menu (#3877)
* [client] Enhance exit node menu functionality with deselect all option

* Hide exit nodes before removal in recreateExitNodeMenu

* recreateExitNodeMenu adding mutex locks

* Refetch exit nodes after deselecting all in exit node menu
2025-06-03 09:49:13 +02:00
Zoltan Papp
af27aaf9af [client] Refactor peer state change subscription mechanism (#3910)
* Refactor peer state change subscription mechanism

Because the code generated new channel for every single event, was easy to miss notification.
Use single channel.

* Fix lint

* Avoid potential deadlock

* Fix test

* Add context

* Fix test
2025-06-03 09:20:33 +02:00
Maycon Santos
35287f8241 [misc] Fail linter workflows on codespell failures (#3913)
* Fail linter workflows on codespell failures

* testing workflow

* remove test
2025-06-03 00:37:51 +02:00
Pedro Maia Costa
07b220d91b [management] REST client impersonation (#3879) 2025-06-02 22:11:28 +02:00
Viktor Liu
41cd4952f1 [client] Apply return traffic rules only if firewall is stateless (#3895) 2025-06-02 12:11:54 +02:00
Zoltan Papp
f16f0c7831 [client] Fix HA router switch (#3889)
* Fix HA router switch.

- Simplify the notification filter logic.
Always send notification if a state has been changed

- Remove IP changes check because we never modify

* Notify only the proper listeners

* Fix test

* Fix TestGetPeerStateChangeNotifierLogic test

* Before lazy connection, when the peer disconnected, the status switched to disconnected.
After implementing lazy connection, the peer state is connecting, so we did not decrease the reference counters on the routes.

* When switch to idle notify the route mgr
2025-06-01 16:08:27 +02:00
Zoltan Papp
aa07b3b87b Fix deadlock (#3904) 2025-05-30 23:38:02 +02:00
Bethuel Mmbaga
2bef214cc0 [management] Fix user groups propagation (#3902) 2025-05-30 18:12:30 +03:00
hakansa
cfb2d82352 [client] Refactor exclude list handling to use a map for permanent connections (#3901)
[client] Refactor exclude list handling to use a map for permanent connections (#3901)
2025-05-30 16:54:49 +03:00
Bethuel Mmbaga
684501fd35 [management] Prevent deletion of peers linked to network routers (#3881)
- Prevent deletion of peers linked to network routers
- Add API endpoint to list all network routers
2025-05-29 18:50:00 +03:00
Zoltan Papp
0492c1724a [client, android] Fix/notifier threading (#3807)
- Fix potential deadlocks
- When adding a listener, immediately notify with the last known IP and fqdn.
2025-05-27 17:12:04 +02:00
Zoltan Papp
6f436e57b5 [server-test] Install libs for i386 tests (#3887)
Install libs for i386 tests
2025-05-27 16:42:06 +02:00
Bethuel Mmbaga
a0d28f9851 [management] Reset test containers after cleanup (#3885) 2025-05-27 14:42:00 +03:00
Zoltan Papp
cdd27a9fe5 [client, android] Fix/android enable server route (#3806)
Enable the server route; otherwise, the manager throws an error and the engine will restart.
2025-05-27 13:32:54 +02:00
Bethuel Mmbaga
5523040acd [management] Add correlated network traffic event schema (#3680) 2025-05-27 13:47:53 +03:00
M. Essam
670446d42e [management/client/rest] Fix panic on unknown errors (#3865) 2025-05-25 16:57:34 +02:00
Pedro Maia Costa
5bed6777d5 [management] force account id on save groups update (#3850) 2025-05-23 14:42:42 +01:00
Pascal Fischer
a0482ebc7b [client] avoid overwriting state manager on iOS (#3870) 2025-05-23 14:04:12 +02:00
Bethuel Mmbaga
2a89d6e47a [management] Extend nameserver match domain validation (#3864)
* Enhance match domain validation logic and add test cases

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove the leading dot and root dot support ns regex

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove support for wildcard ns match domain

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 23:16:19 +02:00
Bethuel Mmbaga
24f932b2ce [management] Update traffic events pagination filters (#3857) 2025-05-22 16:28:14 +03:00
Pedro Maia Costa
c03435061c [management] lazy connection account setting (#3855) 2025-05-22 14:09:00 +01:00
Misha Bragin
8e948739f1 Fix CLA link in the PR template (#3860) 2025-05-22 10:38:58 +02:00
Maycon Santos
9b53cad752 [misc] add CLA note (#3859) 2025-05-21 22:40:36 +02:00
Zoltan Papp
802a18167c [client] Do not reconnect to mgm server in case of handler error (#3856)
* Do not reconnect to mgm server in case of handler error
Set to nil the flow grpc client to nil

* Better error handling
2025-05-21 20:18:21 +02:00
hakansa
e9108ffe6c [client] Add latest gzipped rotated log file to the debug bundle (#3848)
[client] Add latest gzipped rotated log file to the debug bundle
2025-05-21 17:50:54 +03:00
Viktor Liu
e806d9de38 [client] Fix legacy routes when connecting to management servers older than v0.30.0 (#3854) 2025-05-21 13:48:55 +02:00
Zoltan Papp
daa8380df9 [client] Feature/lazy connection (#3379)
With the lazy connection feature, the peer will connect to target peers on-demand. The trigger can be any IP traffic.

This feature can be enabled with the NB_ENABLE_EXPERIMENTAL_LAZY_CONN environment variable.

When the engine receives a network map, it binds a free UDP port for every remote peer, and the system configures WireGuard endpoints for these ports. When traffic appears on a UDP socket, the system removes this listener and starts the peer connection procedure immediately.

Key changes
Fix slow netbird status -d command
Move from engine.go file to conn_mgr.go the peer connection related code
Refactor the iface interface usage and moved interface file next to the engine code
Add new command line flag and UI option to enable feature
The peer.Conn struct is reusable after it has been closed.
Change connection states
Connection states
Idle: The peer is not attempting to establish a connection. This typically means it's in a lazy state or the remote peer is expired.

Connecting: The peer is actively trying to establish a connection. This occurs when the peer has entered an active state and is continuously attempting to reach the remote peer.

Connected: A successful peer-to-peer connection has been established and communication is active.
2025-05-21 11:12:28 +02:00
Bethuel Mmbaga
4785f23fc4 [management] Migrate events sqlite store to gorm (#3837) 2025-05-20 17:00:37 +03:00
Viktor Liu
1d4cfb83e7 [client] Fix UI new version notifier (#3845) 2025-05-20 10:39:17 +02:00
Pascal Fischer
207fa059d2 [management] make locking strength clause optional (#3844) 2025-05-19 16:42:47 +02:00
Viktor Liu
cbcdad7814 [misc] Update issue template (#3842) 2025-05-19 15:41:24 +02:00
Pascal Fischer
701c13807a [management] add flag to disable auto-migration (#3840) 2025-05-19 13:36:24 +02:00
Viktor Liu
99f8dc7748 [client] Offer to remove netbird data in windows uninstall (#3766) 2025-05-16 17:39:30 +02:00
Pascal Fischer
f1de8e6eb0 [management] Make startup period configurable (#3767) 2025-05-16 13:16:51 +02:00
Viktor Liu
b2a10780af [client] Disable dnssec for systemd explicitly (#3831) 2025-05-16 09:43:13 +02:00
Pascal Fischer
43ae79d848 [management] extend rest client lib (#3830) 2025-05-15 18:20:29 +02:00
Pascal Fischer
e520b64c6d [signal] remove stream receive server side (#3820) 2025-05-14 19:28:51 +02:00
hakansa
92c91bbdd8 [client] Add FreeBSD desktop client support to OAuth flow (#3822)
[client] Add FreeBSD desktop client support to OAuth flow
2025-05-14 19:52:02 +03:00
Vlad
adf494e1ac [management] fix a bug with missed extra dns labels for a new peer (#3798) 2025-05-14 17:50:21 +02:00
Vlad
2158461121 [management,client] PKCE add flag parameter prompt=login or max_age (#3824) 2025-05-14 17:48:51 +02:00
Bethuel Mmbaga
0cd4b601c3 [management] Add connection type filter to Network Traffic API (#3815) 2025-05-14 11:15:50 +03:00
223 changed files with 9138 additions and 4694 deletions

View File

@@ -37,17 +37,22 @@ If yes, which one?
**Debug output** **Debug output**
To help us resolve the problem, please attach the following debug output To help us resolve the problem, please attach the following anonymized status output
netbird status -dA netbird status -dA
As well as the file created by Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
If applicable, add screenshots to help explain your problem. If applicable, add screenshots to help explain your problem.
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?** **Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions - [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones) - [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client - [ ] Restarted the NetBird client
- [ ] Disabled other VPN software - [ ] Disabled other VPN software
- [ ] Checked firewall settings - [ ] Checked firewall settings

View File

@@ -13,3 +13,5 @@
- [ ] It is a refactor - [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible) - [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary - [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -223,6 +223,10 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -269,6 +273,10 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -21,7 +21,6 @@ jobs:
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1
golangci: golangci:
strategy: strategy:
fail-fast: false fail-fast: false

View File

@@ -172,13 +172,14 @@ jobs:
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values # check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$' grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445" grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"' grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true' grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy

View File

@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip] return a.ipAnonymizer[ip]
} }
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs // isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 { if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"runtime"
"strings" "strings"
"time" "time"
@@ -98,11 +99,11 @@ var loginCmd = &cobra.Command{
} }
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq, DnsLabels: dnsLabelsReq,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop()) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
} }
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
} }

View File

@@ -26,22 +26,22 @@ import (
) )
const ( const (
externalIPMapFlag = "external-ip-map" externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address" dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass" enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive" rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key" preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name" interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port" wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor" networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access" enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle" uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url" uploadBundleURL = "upload-bundle-url"
) )
var ( var (
@@ -77,9 +77,9 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool
debugUploadBundle bool debugUploadBundle bool
debugUploadBundleURL string debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
@@ -184,6 +184,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"runtime"
"sync" "sync"
"github.com/kardianos/service" "github.com/kardianos/service"
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
} }
func newSVCConfig() *service.Config { func newSVCConfig() *service.Config {
return &service.Config{ config := &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.", Description: "Netbird mesh network client",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string),
} }
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
} }
func newSVC(prg *program, conf *service.Config) (service.Service, error) { func newSVC(prg *program, conf *service.Config) (service.Service, error) {

View File

@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
} }
if logFile != "console" { if logFile != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
} }

View File

@@ -44,7 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected": case "", "idle", "connecting", "connected":
if strings.ToLower(statusFilter) != "" { if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
default: default:
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter) return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
} }
if len(ipsFilter) > 0 { if len(ipsFilter) > 0 {

View File

@@ -6,6 +6,8 @@ const (
disableServerRoutesFlag = "disable-server-routes" disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns" disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
) )
var ( var (
@@ -13,6 +15,8 @@ var (
disableServerRoutes bool disableServerRoutes bool
disableDNS bool disableDNS bool
disableFirewall bool disableFirewall bool
blockLANAccess bool
blockInbound bool
) )
func init() { func init() {
@@ -28,4 +32,11 @@ func init() {
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false, upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.") "Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
} }

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: ` Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(3),
RunE: tracePacket, RunE: tracePacket,

View File

@@ -55,12 +55,11 @@ func init() {
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
) )
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil, upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+ `Sets DNS labels`+
@@ -119,79 +118,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
ic := internal.ConfigInput{ ic, err := setupConfig(customDNSAddressConverted, cmd)
ManagementURL: managementURL, if err != nil {
AdminURL: adminURL, return fmt.Errorf("setup config: %v", err)
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
} }
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
@@ -199,7 +128,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(*ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
} }
@@ -258,21 +187,153 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
if err != nil { if err != nil {
return err return fmt.Errorf("get setup key: %v", err)
} }
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
return &ic, nil
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels, DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0, CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -297,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return err return nil, err
} }
loginRequest.InterfaceName = &interfaceName loginRequest.InterfaceName = &interfaceName
} }
@@ -332,45 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.BlockLanAccess = &blockLANAccess loginRequest.BlockLanAccess = &blockLANAccess
} }
var loginErr error if cmd.Flag(blockInboundFlag).Changed {
loginRequest.BlockInbound = &blockInbound
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
} }
if loginErr != nil { if cmd.Flag(enableLazyConnectionFlag).Changed {
return fmt.Errorf("login failed: %v", loginErr) loginRequest.LazyConnectionEnabled = &lazyConnEnabled
} }
return &loginRequest, nil
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
} }
func validateNATExternalIPs(list []string) error { func validateNATExternalIPs(list []string) error {

View File

@@ -147,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -198,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil, nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
"all", firewall.ProtocolALL,
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
@@ -219,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -2,7 +2,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -248,10 +248,6 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -278,10 +274,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if pair.Masquerade { if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)

View File

@@ -116,6 +116,8 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering( AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,

View File

@@ -170,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -324,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -3,7 +3,6 @@ package nftables
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1").Unmap()
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: add.AsSlice(), Data: ip.AsSlice(),
}, },
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
}) })
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(

View File

@@ -573,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -1006,10 +1002,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }

View File

@@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip net.IP ip tcpip.Address
netstack bool netstack bool
} }
@@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("failed to create NIC: %v", err)
} }
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: ones, PrefixLen: iface.Address().Network.Bits(),
}, },
} }
@@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: iface.Address().IP, ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) { if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1) return net.IPv4(127, 0, 0, 1)
} }
return addr.AsSlice() return addr.AsSlice()
@@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
} }
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil { if errInToOut != nil {
if !isClosedError(errInToOut) { if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut) f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
} }
} }
if errOutToIn != nil { if errOutToIn != nil {
if !isClosedError(errOutToIn) { if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn) f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
} }
} }

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait() wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) { if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr) f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
} }
if inboundErr != nil && !isClosedError(inboundErr) { if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr) f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
} }
var rxPackets, txPackets uint64 var rxPackets, txPackets uint64

View File

@@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
} }
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if ipv4 := ip.To4(); ipv4 != nil { if !ip.Is4() {
high := uint16(ipv4[0]) return
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) }
ipv4 := ip.AsSlice()
if bitmap[high] == nil { high := uint16(ipv4[0])
bitmap[high] = &ipv4LowBitmap{} low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
}
index := low / 32 if bitmap[high] == nil {
bit := low % 32 bitmap[high] = &ipv4LowBitmap{}
bitmap[high].bitmap[index] |= 1 << bit }
ipStr := ipv4.String() index := low / 32
if _, exists := ipv4Set[ipStr]; !exists { bit := low % 32
ipv4Set[ipStr] = struct{}{} bitmap[high].bitmap[index] |= 1 << bit
*ipv4Addresses = append(*ipv4Addresses, ipStr)
} if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
} }
} }
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
} }
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil { addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{}) ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []string var ipv4Addresses []netip.Addr
// 127.0.0.0/8 // 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{} newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match - addresses 32 apart", name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.33"), testIP: netip.MustParseAddr("192.168.1.33"),
expected: false, expected: false,
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: netip.MustParseAddr("fe80::1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,

View File

@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -39,8 +39,12 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack // EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible // Default off as it might be security risk because sockets listening on localhost only will become accessible.
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
@@ -71,7 +75,6 @@ type Manager struct {
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -148,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
} }
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
} }
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
@@ -269,7 +277,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported(): case !m.netstack && m.nativeFirewall != nil:
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
@@ -326,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
@@ -606,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true return true
} }
if m.stateful { // for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, size)
}
return false return false
} }
@@ -777,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true return true
} }
// if running in netstack mode we need to pass this to the forwarder // If requested we pass local traffic to internal interfaces to the forwarder.
if m.netstack && m.localForwarding { // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
return m.handleNetstackLocalTraffic(packetData) if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
return m.handleForwardedLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
@@ -789,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return false return false
} }
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool { func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load() fwd := m.forwarder.Load()
if fwd == nil { if fwd == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
@@ -1088,11 +1099,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true return true
} }
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not

View File

@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup // Apply scenario-specific setup
sc.setupFunc(manager) sc.setupFunc(manager)
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table // Pre-populate connection table
srcIPs := generateRandomIPs(count) srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count)
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0] srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "post_handshake", state: "post_handshake",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err) require.NoError(b, err)
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept) dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -19,12 +19,8 @@ import (
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
localIP := net.ParseIP("100.10.0.100") localIP := netip.MustParseAddr("100.10.0.100")
wgNet := &net.IPNet{ wgNet := netip.MustParsePrefix("100.10.0.0/16")
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
@@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
localIP, wgNet, err := net.ParseCIDR(network) wgNet := netip.MustParsePrefix(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: localIP, IP: wgNet.Addr(),
Network: wgNet, Network: wgNet,
} }
}, },
@@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
@@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
@@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{

View File

@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) { if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }

View File

@@ -12,6 +12,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var zeroKey wgtypes.Key
type KernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
@@ -201,14 +203,71 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() { func (c *KernelConfigurer) Close() {
} }
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) { func (c *KernelConfigurer) FullStats() (*Stats, error) {
peer, err := c.getPeer(c.deviceName, peerKey) wg, err := wgctrl.New()
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err) return nil, fmt.Errorf("wgctl: %w", err)
} }
return WGStats{ defer func() {
LastHandshake: peer.LastHandshakeTime, err = wg.Close()
TxBytes: peer.TransmitBytes, if err != nil {
RxBytes: peer.ReceiveBytes, log.Errorf("Got error while closing wgctl: %v", err)
}, nil }
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}
}
return stats, nil
} }

View File

@@ -1,6 +1,7 @@
package configurer package configurer
import ( import (
"encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
@@ -17,6 +18,20 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct { type WGUSPConfigurer struct {
@@ -178,6 +193,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error
@@ -217,91 +241,75 @@ func (t *WGUSPConfigurer) Close() {
} }
} }
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) { func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet() ipc, err := t.device.IpcGet()
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err) return nil, fmt.Errorf("ipc get: %w", err)
} }
stats, err := findPeerInfo(ipc, peerKey, []string{ return parseTransfers(ipc)
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
} }
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) { func parseTransfers(ipc string) (map[string]WGStats, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) stats := make(map[string]WGStats)
if err != nil { var (
return nil, fmt.Errorf("parse key: %w", err) currentKey string
} currentStats WGStats
hasPeer bool
hexKey := hex.EncodeToString(peerKeyParsed[:]) )
lines := strings.Split(ipc, "\n")
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key, // If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop. // this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") && foundPeer { if strings.HasPrefix(line, "public_key=") {
break peerID := strings.TrimPrefix(line, "public_key=")
} h, err := hex.DecodeString(peerID)
if err != nil {
// Identify the peer with the specific public key return nil, fmt.Errorf("decode peerID: %w", err)
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
for _, key := range searchConfigKeys {
if foundPeer && strings.HasPrefix(line, key+"=") {
v := strings.SplitN(line, "=", 2)
configFound[v[0]] = v[1]
} }
currentKey = base64.StdEncoding.EncodeToString(h)
currentStats = WGStats{} // Reset stats for the new peer
hasPeer = true
stats[currentKey] = currentStats
continue
}
if !hasPeer {
continue
}
key := strings.SplitN(line, "=", 2)
if len(key) != 2 {
continue
}
switch key[0] {
case ipcKeyLastHandshakeTimeSec:
hs, err := toLastHandshake(key[1])
if err != nil {
return nil, err
}
currentStats.LastHandshake = hs
stats[currentKey] = currentStats
case ipcKeyRxBytes:
rxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
currentStats.RxBytes = rxBytes
stats[currentKey] = currentStats
case ipcKeyTxBytes:
TxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
currentStats.TxBytes = TxBytes
stats[currentKey] = currentStats
} }
} }
// todo: use multierr return stats, nil
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
} }
func toWgUserspaceString(wgCfg wgtypes.Config) string { func toWgUserspaceString(wgCfg wgtypes.Config) string {
@@ -355,9 +363,154 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String() return sb.String()
} }
func toLastHandshake(stringVar string) (time.Time, error) {
sec, err := strconv.ParseInt(stringVar, 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
return time.Unix(sec, 0), nil
}
func toBytes(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark return nbnet.ControlPlaneMark
} }
return 0 return 0
} }
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -2,10 +2,8 @@ package configurer
import ( import (
"encoding/hex" "encoding/hex"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -34,58 +32,35 @@ errno=0
` `
func Test_findPeerInfo(t *testing.T) { func Test_parseTransfers(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
peerKey string peerKey string
searchKeys []string want WGStats
want map[string]string
wantErr bool
}{ }{
{ {
name: "single", name: "single",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
searchKeys: []string{"tx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 0,
"tx_bytes": "38333", RxBytes: 0,
}, },
wantErr: false,
}, },
{ {
name: "multiple", name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes", "rx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 38333,
"tx_bytes": "38333", RxBytes: 2224,
"rx_bytes": "2224",
}, },
wantErr: false,
}, },
{ {
name: "lastpeer", name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58", peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "rx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 1212111,
"tx_bytes": "1212111", RxBytes: 1929999999,
"rx_bytes": "1929999999",
}, },
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
key, err := wgtypes.NewKey(res) key, err := wgtypes.NewKey(res)
require.NoError(t, err) require.NoError(t, err)
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys) stats, err := parseTransfers(ipcFixture)
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)) if err != nil {
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys) require.NoError(t, err)
return
}
stat, ok := stats[key.String()]
if !ok {
require.True(t, ok)
return
}
require.Equal(t, tt.want, stat)
}) })
} }
} }

View File

@@ -0,0 +1,24 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -1,7 +1,6 @@
package device package device
import ( import (
"net"
"net/netip" "net/netip"
"sync" "sync"
@@ -24,9 +23,6 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
} }
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets

View File

@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1) dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@@ -16,5 +16,6 @@ type WGConfigurer interface {
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error
Close() Close()
GetStats(peerKey string) (configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
} }

View File

@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
ip := address.IP.String() ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil
@@ -212,9 +211,13 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device() return w.tun.Device()
} }
// GetStats returns the last handshake time, rx and tx bytes for the given peer // GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats(peerKey) return w.configurer.GetStats()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
} }
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {

View File

@@ -5,7 +5,6 @@
package mocks package mocks
import ( import (
net "net"
"net/netip" "net/netip"
reflect "reflect" reflect "reflect"
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
} }
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,8 +1,6 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"strconv" "strconv"
@@ -15,8 +13,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address net.IP address netip.Addr
dnsAddress net.IP dnsAddress netip.Addr
mtu int mtu int
listenAddress string listenAddress string
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
} }
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{addr.Unmap()}, []netip.Addr{t.address},
[]netip.Addr{dnsAddr.Unmap()}, []netip.Addr{t.dnsAddress},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -2,28 +2,27 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net" "net/netip"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP net.IP IP netip.Addr
Network *net.IPNet Network netip.Prefix
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) prefix, err := netip.ParsePrefix(address)
if err != nil { if err != nil {
return Address{}, err return Address{}, err
} }
return Address{ return Address{
IP: ip, IP: prefix.Addr().Unmap(),
Network: network, Network: prefix.Masked(),
}, nil }, nil
} }
func (addr Address) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -24,6 +24,8 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run" !define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -49,6 +51,10 @@ ShowInstDetails Show
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}" !define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@@ -58,9 +64,6 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
!define MUI_UNABORTWARNING !define MUI_UNABORTWARNING
@@ -70,13 +73,16 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
!insertmacro MUI_UNPAGE_WELCOME
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM !insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES !insertmacro MUI_UNPAGE_INSTFILES
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
Var AutostartCheckbox Var AutostartCheckbox
Var AutostartEnabled Var AutostartEnabled
; Variables for uninstall data deletion option
Var DeleteDataCheckbox
Var DeleteDataEnabled
###################################################################### ######################################################################
; Function to create the autostart options page ; Function to create the autostart options page
@@ -104,8 +114,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts" ${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked ${NSD_Check} $AutostartCheckbox
StrCpy $AutostartEnabled "1" ; Default to enabled StrCpy $AutostartEnabled "1"
nsDialogs::Show nsDialogs::Show
FunctionEnd FunctionEnd
@@ -115,6 +125,30 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled ${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd FunctionEnd
; Function to create the uninstall data deletion page
Function un.DeleteDataPage
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
Pop $DeleteDataCheckbox
${NSD_Uncheck} $DeleteDataCheckbox
StrCpy $DeleteDataEnabled "0"
nsDialogs::Show
FunctionEnd
; Function to handle leaving the data deletion page
Function un.DeleteDataPageLeave
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -176,10 +210,10 @@ ${EndIf}
FunctionEnd FunctionEnd
###################################################################### ######################################################################
Section -MainProgram Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
# SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -225,31 +259,58 @@ SectionEnd
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry ; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."
${If} $DeleteDataEnabled == "1"
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
ClearErrors
RMDir /r "${NETBIRD_DATA_DIR}"
IfErrors 0 +2 ; If no errors, jump over the message
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
DetailPrint "Netbird data directory removal complete."
${Else}
DetailPrint "User did not opt to delete Netbird data."
${EndIf}
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
DetailPrint "Waiting for service handle to be released..."
Sleep 3000 Sleep 3000
DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
DetailPrint "Removing shortcuts..."
SetShellVarContext all SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk" Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk" Delete "$SMPROGRAMS\${APP_NAME}.lnk"
DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR" EnVar::DeleteValue "path" "$INSTDIR"
DetailPrint "Uninstallation finished."
SectionEnd SectionEnd

View File

@@ -58,6 +58,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
@@ -69,20 +74,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total) time.Since(start), total)
}() }()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@@ -291,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete. if d.firewall.IsStateful() {
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")

View File

@@ -1,13 +1,14 @@
package acl package acl
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) assert.Equal(t, 0, len(acl.peerRulesPairs))
return } else {
assert.Equal(t, 2, len(acl.peerRulesPairs))
} }
}) })
@@ -94,12 +91,13 @@ func TestDefaultManager(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules expectedRules := 2
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied") expectedRules = 1 // only the inbound rule
return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
previousCount++ previousCount++
} }
} }
if previousCount != 1 {
t.Errorf("old rule was not removed") expectedPreviousCount := 0
if !fw.IsStateful() {
expectedPreviousCount = 1
} }
assert.Equal(t, expectedPreviousCount, previousCount)
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 { acl.ApplyFiltering(networkMap, false)
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) assert.Equal(t, 0, len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) expectedRules := 1
return if fw.IsStateful() {
expectedRules = 1 // only inbound allow-all rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
})
}
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
// In stateless mode, we should have both inbound and outbound rules
assert.False(t, fw.IsStateful())
assert.Equal(t, 2, len(acl.peerRulesPairs))
}) })
} }
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) rules, _ := manager.squashAcceptRules(networkMap)
if len(rules) != 2 { assert.Equal(t, 2, len(rules))
t.Errorf("rules should contain 2, got: %v", rules)
return
}
r := rules[0] r := rules[0]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1] r = rules[1]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
} }
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
@@ -291,9 +326,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
} }
manager := &DefaultManager{} manager := &DefaultManager{}
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) { rules, _ := manager.squashAcceptRules(networkMap)
t.Errorf("we should get the same amount of rules as output, got %v", len(rules)) assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
} }
func TestDefaultManagerEnableSSHRules(t *testing.T) { func TestDefaultManagerEnableSSHRules(t *testing.T) {
@@ -336,33 +370,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 { expectedRules := 3
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) if fw.IsStateful() {
return expectedRules = 3 // 2 inbound rules + SSH rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
} }

View File

@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxDesktopClient { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }

View File

@@ -101,7 +101,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
} }
if !p.providerConfig.DisablePromptLogin { if !p.providerConfig.DisablePromptLogin {
params = append(params, oauth2.SetAuthURLParam("prompt", "login")) if p.providerConfig.LoginFlag.IsPromptLogin() {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
} }
authURL := p.oAuthConfig.AuthCodeURL(state, params...) authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -7,15 +7,36 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
) )
func TestPromptLogin(t *testing.T) { func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct { tt := []struct {
name string name string
prompt bool loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
}{ }{
{"PromptLogin", true}, {
{"NoPromptLogin", false}, name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
} }
for _, tc := range tt { for _, tc := range tt {
@@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize", AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"}, RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true, UseIDToken: true,
DisablePromptLogin: !tc.prompt, LoginFlag: tc.loginFlag,
} }
pkce, err := NewPKCEAuthorizationFlow(config) pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil { if err != nil {
@@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to request auth info: %v", err) t.Fatalf("Failed to request auth info: %v", err)
} }
pattern := "prompt=login"
if tc.prompt { if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, pattern) require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else { } else {
require.NotContains(t, authInfo.VerificationURIComplete, pattern) require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
} }
}) })
} }

View File

@@ -68,12 +68,14 @@ type ConfigInput struct {
DisableServerRoutes *bool DisableServerRoutes *bool
DisableDNS *bool DisableDNS *bool
DisableFirewall *bool DisableFirewall *bool
BlockLANAccess *bool
BlockLANAccess *bool BlockInbound *bool
DisableNotifications *bool DisableNotifications *bool
DNSLabels domain.List DNSLabels domain.List
LazyConnectionEnabled *bool
} }
// Config Configuration type // Config Configuration type
@@ -96,8 +98,8 @@ type Config struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool
BlockLANAccess bool BlockInbound bool
DisableNotifications *bool DisableNotifications *bool
@@ -138,6 +140,8 @@ type Config struct {
ClientCertKeyPath string ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"` ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -479,6 +483,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
if *input.BlockInbound {
log.Infof("blocking inbound connections")
} else {
log.Infof("allowing inbound connections")
}
config.BlockInbound = *input.BlockInbound
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications { if *input.DisableNotifications {
log.Infof("disabling notifications") log.Infof("disabling notifications")
@@ -524,6 +538,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
updated = true
}
return updated, nil return updated, nil
} }

303
client/internal/conn_mgr.go Normal file
View File

@@ -0,0 +1,303 @@
package internal
import (
"context"
"os"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
// - Managing lazy connections via the lazyConnManager
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
return nil
}
if enabled {
// if the lazy connection manager is already started, do not start it again
if e.lazyConnMgr != nil {
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager(ctx)
} else {
if e.lazyConnMgr == nil {
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
e.closeManager(ctx)
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
if e.lazyConnMgr == nil {
return
}
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for peerID := range peerIDs {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
continue
}
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(e.ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err)
}
}
}
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
return true
}
if !e.isStartedWithLazyMgr() {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if !lazyconn.IsSupported(conn.AgentVersionString()) {
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerKey,
AllowedIPs: conn.WgConfig().AllowedIps,
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if excluded {
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
conn.Log.Infof("peer added to lazy conn manager")
return
}
func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn, ok := e.peerStore.Remove(peerKey)
if !ok {
return
}
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
}
e.lazyConnMgr.RemovePeer(peerKey)
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(e.ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
func (e *ConnMgr) Close() {
if !e.isStartedWithLazyMgr() {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
}
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
ctx, cancel := context.WithCancel(parentCtx)
e.ctx = ctx
e.ctxCancel = cancel
e.wg.Add(1)
go func() {
defer e.wg.Done()
e.lazyConnMgr.Start(ctx)
}()
}
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {
if e.lazyConnMgr == nil {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
for _, peerID := range e.peerStore.PeersPubKey() {
e.peerStore.PeerConnOpen(ctx, peerID)
}
}
func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.ctxCancel != nil
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {
return nil
}
parsedMinutes, err := strconv.Atoi(envValue)
if err != nil || parsedMinutes <= 0 {
return nil
}
d := time.Duration(parsedMinutes) * time.Minute
return &d
}

View File

@@ -436,11 +436,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DNSRouteInterval: config.DNSRouteInterval, DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes, DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes, DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
DisableDNS: config.DisableDNS, DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall, DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
BlockInbound: config.BlockInbound,
BlockLANAccess: config.BlockLANAccess, LazyConnectionEnabled: config.LazyConnectionEnabled,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -481,7 +483,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
return signalClient, nil return signalClient, nil
} }
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) { func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey() serverPublicKey, err := client.GetServerPublicKey()

View File

@@ -4,6 +4,7 @@ import (
"archive/zip" "archive/zip"
"bufio" "bufio"
"bytes" "bytes"
"compress/gzip"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -269,11 +270,21 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err) log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
} }
if g.logFile != "console" { if err := g.addWgShow(); err != nil {
if err := g.addLogfile(); err != nil { log.Errorf("Failed to add wg show output: %v", err)
return fmt.Errorf("add log file: %w", err)
}
} }
if g.logFile != "console" && g.logFile != "" {
if err := g.addLogfile(); err != nil {
log.Errorf("Failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs: %v", err)
}
return nil return nil
} }
@@ -365,17 +376,34 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled)) configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive)) configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
if g.internalConfig.ServerSSHAllowed != nil { if g.internalConfig.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed)) configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
} }
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS)) configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall)) configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
if g.internalConfig.DisableNotifications != nil {
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
}
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
if g.internalConfig.ClientCertPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath))
}
if g.internalConfig.ClientCertKeyPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {
@@ -533,6 +561,33 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err) return fmt.Errorf("add client log file to zip: %w", err)
} }
// add latest rotated log file
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
stdErrLogPath := filepath.Join(logDir, errorLogFile) stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile) stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
@@ -563,16 +618,13 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
} }
}() }()
var logReader io.Reader var logReader io.Reader = logFile
if g.anonymize { if g.anonymize {
var writer *io.PipeWriter var writer *io.PipeWriter
logReader, writer = io.Pipe() logReader, writer = io.Pipe()
go anonymizeLog(logFile, writer, g.anonymizer) go anonymizeLog(logFile, writer, g.anonymizer)
} else {
logReader = logFile
} }
if err := g.addFileToZip(logReader, targetName); err != nil { if err := g.addFileToZip(logReader, targetName); err != nil {
return fmt.Errorf("add %s to zip: %w", targetName, err) return fmt.Errorf("add %s to zip: %w", targetName, err)
} }
@@ -580,6 +632,44 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
return nil return nil
} }
// addSingleLogFileGz adds a single gzipped log file to the archive
func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
f, err := os.Open(logPath)
if err != nil {
return fmt.Errorf("open gz log file %s: %w", targetName, err)
}
defer f.Close()
gzr, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("create gzip reader: %w", err)
}
defer gzr.Close()
var logReader io.Reader = gzr
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(gzr, pw, g.anonymizer)
}
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
if _, err := io.Copy(gw, logReader); err != nil {
return fmt.Errorf("re-gzip: %w", err)
}
if err := gw.Close(); err != nil {
return fmt.Errorf("close gzip writer: %w", err)
}
if err := g.addFileToZip(&buf, targetName); err != nil {
return fmt.Errorf("add anonymized gz: %w", err)
}
return nil
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error { func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{ header := &zip.FileHeader{
Name: filename, Name: filename,

View File

@@ -4,17 +4,104 @@ package debug
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sort" "sort"
"strings" "strings"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
maxLogEntries = 100000
maxLogAge = 7 * 24 * time.Hour // Last 7 days
)
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
func (g *BundleGenerator) trySystemdLogFallback() error {
log.Debug("Attempting to collect systemd journal logs")
serviceName := getServiceName()
journalLogs, err := getSystemdLogs(serviceName)
if err != nil {
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
}
if strings.Contains(journalLogs, "No recent log entries found") {
log.Debug("No recent log entries found in systemd journal")
return nil
}
if g.anonymize {
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
}
logReader := strings.NewReader(journalLogs)
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
if err := g.addFileToZip(logReader, fileName); err != nil {
return fmt.Errorf("add systemd logs to bundle: %w", err)
}
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
return nil
}
// getServiceName gets the service name from environment or defaults to netbird
func getServiceName() string {
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
return unitName
}
return "netbird"
}
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
func getSystemdLogs(serviceName string) (string, error) {
args := []string{
"-u", fmt.Sprintf("%s.service", serviceName),
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
"--lines", fmt.Sprintf("%d", maxLogEntries),
"--no-pager",
"--output", "short-iso",
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "journalctl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
}
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("journalctl command not found: %w", err)
}
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
}
logs := stdout.String()
if strings.TrimSpace(logs) == "" {
return "No recent log entries found in systemd journal", nil
}
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
serviceName, maxLogEntries, maxLogAge.String())
return header + logs, nil
}
// addFirewallRules collects and adds firewall rules to the archive // addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules") log.Info("Collecting firewall rules")
@@ -481,7 +568,7 @@ func formatExpr(exp expr.Any) string {
case *expr.Fib: case *expr.Fib:
return formatFib(e) return formatFib(e)
case *expr.Target: case *expr.Target:
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets return fmt.Sprintf("jump %s", e.Name)
case *expr.Immediate: case *expr.Immediate:
if e.Register == 1 { if e.Register == 1 {
return formatImmediateData(e.Data) return formatImmediateData(e.Data)

View File

@@ -6,3 +6,9 @@ package debug
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
return nil return nil
} }
func (g *BundleGenerator) trySystemdLogFallback() error {
// Systemd is only available on Linux
// TODO: Add BSD support
return nil
}

View File

@@ -0,0 +1,66 @@
package debug
import (
"bytes"
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGIface interface {
FullStats() (*configurer.Stats, error)
}
func (g *BundleGenerator) addWgShow() error {
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err
}
output := g.toWGShowFormat(result)
reader := bytes.NewReader([]byte(output))
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
return fmt.Errorf("add wg show to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
if s.FWMark != 0 {
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
}
for _, peer := range s.Peers {
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
if peer.Endpoint.IP != nil {
if g.anonymize {
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
} else {
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
}
}
if len(peer.AllowedIPs) > 0 {
var ipStrings []string
for _, ipnet := range peer.AllowedIPs {
ipStrings = append(ipStrings, ipnet.String())
}
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
sb.WriteString(" preshared key: (hidden)\n")
}
}
return sb.String()
}

View File

@@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"net" "net/netip"
"slices" "slices"
"strings" "strings"
@@ -12,13 +12,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData) ip, err := netip.ParseAddr(aRecord.RData)
if ip == nil || ip.To4() == nil { if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
if !ipNet.Contains(ip) { if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
} }
// generateReverseZoneName creates the reverse DNS zone name for a given network // generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) { func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask) networkIP := network.Masked().Addr()
maskOnes, _ := ipNet.Mask.Size()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// round up to nearest byte // round up to nearest byte
octetsToUse := (maskOnes + 7) / 8 octetsToUse := (network.Bits() + 7) / 8
octets := strings.Split(networkIP.String(), ".") octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) { if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
} }
reverseOctets := make([]string, octetsToUse) reverseOctets := make([]string, octetsToUse)
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
} }
// collectPTRRecords gathers all PTR records for the given network from A records // collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones { for _, zone := range config.CustomZones {
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
continue continue
} }
if ptrRecord, ok := createPTRRecord(record, ipNet); ok { if ptrRecord, ok := createPTRRecord(record, prefix); ok {
records = append(records, ptrRecord) records = append(records, ptrRecord)
} }
} }
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
} }
// addReverseZone adds a reverse DNS zone to the configuration for the given network // addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { func addReverseZone(config *nbdns.Config, network netip.Prefix) {
zoneName, err := generateReverseZoneName(ipNet) zoneName, err := generateReverseZoneName(network)
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
return return
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
return return
} }
records := collectPTRRecords(config, ipNet) records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{ reverseZone := nbdns.CustomZone{
Domain: zoneName, Domain: zoneName,

View File

@@ -1,6 +1,7 @@
package dns package dns
import ( import (
"fmt"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -148,61 +149,42 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
qname := strings.ToLower(r.Question[0].Name) qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
handlers := slices.Clone(c.handlers) handlers := slices.Clone(c.handlers)
c.mu.RUnlock() c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers)) var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers { for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
} }
log.Trace(strings.TrimSuffix(b.String(), "\n"))
} }
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
var matched bool matched := c.isHandlerMatch(qname, entry)
switch {
case entry.Pattern == ".": if matched {
matched = true log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
case entry.IsWildcard: qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) chainWriter := &ResponseWriterChain{
default: ResponseWriter: w,
// For non-wildcard patterns: origPattern: entry.OrigPattern,
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = strings.EqualFold(qname, entry.Pattern)
} }
} entry.Handler.ServeDNS(chainWriter, r)
if !matched { // If handler wants to continue, try next handler
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false", if chainWriter.shouldContinue {
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority) log.Tracef("handler requested continue to next handler for domain=%s", qname)
continue continue
}
return
} }
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler")
continue
}
return
} }
// No handler matched or all handlers passed // No handler matched or all handlers passed
@@ -213,3 +195,22 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
} }
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
return strings.EqualFold(qname, entry.Pattern)
}
}
}

View File

@@ -1,11 +1,14 @@
package dns package dns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os/exec"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -41,6 +44,20 @@ const (
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
registrationEnabledKey = "RegistrationEnabled"
maxNumberOfAddressesToRegisterKey = "MaxNumberOfAddressesToRegister"
// NetBIOS/WINS settings
netbtInterfacePath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces`
netbiosOptionsKey = "NetbiosOptions"
// NetBIOS option values: 0 = from DHCP, 1 = enabled, 2 = disabled
netbiosFromDHCP = 0
netbiosEnabled = 1
netbiosDisabled = 2
// RP_FORCE: Reapply all policies even if no policy change was detected // RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1 rpForce = 0x1
) )
@@ -67,16 +84,85 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
log.Infof("detected GPO DNS policy configuration, using policy store") log.Infof("detected GPO DNS policy configuration, using policy store")
} }
return &registryConfigurator{ configurator := &registryConfigurator{
guid: guid, guid: guid,
gpo: useGPO, gpo: useGPO,
}, nil }
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
return configurator, nil
} }
func (r *registryConfigurator) supportCustomPort() bool { func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) configureInterface() error {
var merr *multierror.Error
if err := r.disableDNSRegistrationForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable DNS registration: %w", err))
}
if err := r.disableWINSForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable WINS: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableDNSRegistrationForInterface() error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closer(regKey)
var merr *multierror.Error
if err := regKey.SetDWordValue(disableDynamicUpdateKey, 1); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", disableDynamicUpdateKey, err))
}
if err := regKey.SetDWordValue(registrationEnabledKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", registrationEnabledKey, err))
}
if err := regKey.SetDWordValue(maxNumberOfAddressesToRegisterKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", maxNumberOfAddressesToRegisterKey, err))
}
if merr == nil || len(merr.Errors) == 0 {
log.Infof("disabled DNS registration for interface %s", r.guid)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableWINSForInterface() error {
netbtKeyPath := fmt.Sprintf(`%s\Tcpip_%s`, netbtInterfacePath, r.guid)
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
regKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create NetBT interface key %s: %w", netbtKeyPath, err)
}
}
defer closer(regKey)
// NetbiosOptions: 2 = disabled
if err := regKey.SetDWordValue(netbiosOptionsKey, netbiosDisabled); err != nil {
return fmt.Errorf("set %s: %w", netbiosOptionsKey, err)
}
log.Infof("disabled WINS/NetBIOS for interface %s", r.guid)
return nil
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if config.RouteAll { if config.RouteAll {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil { if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
@@ -119,9 +205,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
} }
if err := r.flushDNSCache(); err != nil { go r.flushDNSCache()
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }
@@ -191,7 +275,25 @@ func (r *registryConfigurator) string() string {
return "registry" return "registry"
} }
func (r *registryConfigurator) flushDNSCache() error { func (r *registryConfigurator) registerDNS() {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
// nolint:misspell
cmd := exec.CommandContext(ctx, "ipconfig", "/registerdns")
out, err := cmd.CombinedOutput()
if err != nil {
log.Errorf("failed to register DNS: %v, output: %s", err, out)
return
}
log.Info("registered DNS names")
}
func (r *registryConfigurator) flushDNSCache() {
r.registerDNS()
// dnsFlushResolverCacheFn.Call() may panic if the func is not found // dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
@@ -202,13 +304,14 @@ func (r *registryConfigurator) flushDNSCache() error {
ret, _, err := dnsFlushResolverCacheFn.Call() ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 { if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) { if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err) log.Errorf("DnsFlushResolverCache failed: %v", err)
return
} }
return fmt.Errorf("DnsFlushResolverCache failed") log.Errorf("DnsFlushResolverCache failed")
return
} }
log.Info("flushed DNS cache") log.Info("flushed DNS cache")
return nil
} }
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
@@ -263,9 +366,7 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := r.flushDNSCache(); err != nil { go r.flushDNSCache()
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }

View File

@@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
} }
} }
log.Debugf("extra match domains: %v", s.extraDomains) log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err) log.Errorf("failed to apply DNS host manager update: %v", err)

View File

@@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
} }
func (w *mocWGIface) Address() wgaddr.Address { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{ return wgaddr.Address{
IP: ip, IP: netip.MustParseAddr("100.66.100.1"),
Network: network, Network: netip.MustParsePrefix("100.66.100.0/24"),
} }
} }
@@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil { if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err) t.Errorf("set packet filter: %v", err)

View File

@@ -24,11 +24,15 @@ type ServiceViaMemory struct {
} }
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{ s := &ServiceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(), runtimeIP: lastIP.String(),
runtimePort: defaultPort, runtimePort: defaultPort,
} }
return s return s
@@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
} }
firstLayerDecoder := layers.LayerTypeIPv4 firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil { if s.wgInterface.Address().IP.Is6() {
firstLayerDecoder = layers.LayerTypeIPv6 firstLayerDecoder = layers.LayerTypeIPv6
} }

View File

@@ -1,33 +0,0 @@
package dns
import (
"net"
"testing"
nbnet "github.com/netbirdio/netbird/util/net"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -30,9 +30,12 @@ const (
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS" systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusResolvConfModeForeign = "foreign" systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
dnsSecDisabled = "no"
) )
type systemdDbusConfigurator struct { type systemdDbusConfigurator struct {
@@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
Family: unix.AF_INET, Family: unix.AF_INET,
Address: ipAs4[:], Address: ipAs4[:],
} }
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}) if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
if err != nil { return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err) }
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSSEC to 'no': %v", err)
} }
var ( var (

View File

@@ -3,6 +3,7 @@ package dns
import ( import (
"context" "context"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@@ -23,8 +24,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string, domain string,

View File

@@ -4,7 +4,7 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -19,8 +19,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@@ -18,16 +19,16 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP net.IP lIP netip.Addr
lNet *net.IPNet lNet netip.Prefix
interfaceName string interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
interfaceName string, interfaceName string,
ip net.IP, ip netip.Addr,
net *net.IPNet, net netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} }
client.DialTimeout = timeout client.DialTimeout = timeout
upstreamIP := net.ParseIP(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil { if err != nil {
@@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName) index, err := getInterfaceIndex(interfaceName)
if err != nil { if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err) log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
@@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: ip, IP: ip.AsSlice(),
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,

View File

@@ -2,7 +2,7 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {

View File

@@ -5,7 +5,6 @@ package dns
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -18,5 +17,4 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
} }

View File

@@ -1,7 +1,6 @@
package dns package dns
import ( import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -13,6 +12,5 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error) GetInterfaceGUIDString() (string, error)
} }

View File

@@ -38,6 +38,7 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
@@ -120,8 +121,10 @@ type EngineConfig struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
BlockLANAccess bool LazyConnectionEnabled bool
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -134,6 +137,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer // peerConns is a map that holds all the peers that are known to this peer
peerStore *peerstore.Store peerStore *peerstore.Store
connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc afterPeerHook nbnet.RemoveHookFunc
@@ -170,7 +175,8 @@ type Engine struct {
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server sshServer nbssh.Server
statusRecorder *peer.Status statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
firewall firewallManager.Manager firewall firewallManager.Manager
routeManager routemanager.Manager routeManager routemanager.Manager
@@ -235,6 +241,8 @@ func NewEngine(
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
} }
path := statemanager.GetDefaultStatePath()
if runtime.GOOS == "ios" { if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) { if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath) err := createFile(mobileDep.StateFilePath)
@@ -244,11 +252,9 @@ func NewEngine(
} }
} }
engine.stateManager = statemanager.New(mobileDep.StateFilePath) path = mobileDep.StateFilePath
}
if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path)
} }
engine.stateManager = statemanager.New(path)
return engine return engine
} }
@@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
}
// stopping network monitor first to avoid starting the engine again // stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil { if e.networkMonitor != nil {
e.networkMonitor.Stop() e.networkMonitor.Stop()
@@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
err := e.removeAllPeers() if err := e.removeAllPeers(); err != nil {
if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
} }
@@ -350,6 +359,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("new wg interface: %w", err) return fmt.Errorf("new wg interface: %w", err)
} }
e.wgInterface = wgIface e.wgInterface = wgIface
e.statusRecorder.SetWgIface(wgIface)
// start flow manager right after interface creation // start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey() publicKey := e.config.WgPrivateKey.PublicKey()
@@ -371,7 +381,6 @@ func (e *Engine) Start() error {
return fmt.Errorf("run rosenpass manager: %w", err) return fmt.Errorf("run rosenpass manager: %w", err)
} }
} }
e.stateManager.Start() e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer() initialRoutes, dnsServer, err := e.newDnsServer()
@@ -405,8 +414,7 @@ func (e *Engine) Start() error {
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
err = e.wgInterfaceCreate() if err = e.wgInterfaceCreate(); err != nil {
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close() e.close()
return fmt.Errorf("create wg interface: %w", err) return fmt.Errorf("create wg interface: %w", err)
@@ -423,7 +431,8 @@ func (e *Engine) Start() error {
return fmt.Errorf("up wg interface: %w", err) return fmt.Errorf("up wg interface: %w", err)
} }
if e.firewall != nil { // if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall) e.acl = acl.NewDefaultManager(e.firewall)
} }
@@ -442,6 +451,11 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
} }
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start() e.srWatcher.Start()
@@ -450,7 +464,6 @@ func (e *Engine) Start() error {
// starting network monitor at the very last to avoid disruptions // starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor() e.startNetworkMonitor()
return nil return nil
} }
@@ -475,11 +488,9 @@ func (e *Engine) createFirewall() error {
} }
func (e *Engine) initFirewall() error { func (e *Engine) initFirewall() error {
if e.firewall.IsServerRouteSupported() { if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { e.close()
e.close() return fmt.Errorf("enable server router: %w", err)
return fmt.Errorf("enable server router: %w", err)
}
} }
if e.config.BlockLANAccess { if e.config.BlockLANAccess {
@@ -513,6 +524,11 @@ func (e *Engine) initFirewall() error {
} }
func (e *Engine) blockLanAccess() { func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
return
}
var merr *multierror.Error var merr *multierror.Error
// TODO: keep this updated // TODO: keep this updated
@@ -550,6 +566,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
if !ok {
continue
}
if currentPeer.AgentVersionString() != p.AgentVersion {
modified = append(modified, p)
continue
}
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey) allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
if !ok { if !ok {
continue continue
@@ -559,8 +585,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
continue continue
} }
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()) if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err) log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
} }
} }
@@ -621,16 +646,11 @@ func (e *Engine) removePeer(peerKey string) error {
e.sshServer.RemoveAuthorizedKey(peerKey) e.sshServer.RemoveAuthorizedKey(peerKey)
} }
defer func() { e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey)
if err != nil {
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
}
}()
conn, exists := e.peerStore.Remove(peerKey) err := e.statusRecorder.RemovePeer(peerKey)
if exists { if err != nil {
conn.Close() log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
} }
return nil return nil
} }
@@ -780,56 +800,58 @@ func isNil(server nbssh.Server) bool {
} }
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Infof("SSH server is disabled because inbound connections are blocked")
return nil
}
if !e.config.ServerSSHAllowed { if !e.config.ServerSSHAllowed {
log.Warnf("running SSH server is not permitted") log.Info("SSH server is not enabled")
return nil return nil
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
return nil
} }
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
return nil
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
@@ -952,12 +974,33 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil return nil
} }
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err)
}
if e.firewall != nil { if e.firewall != nil {
if localipfw, ok := e.firewall.(localIpUpdater); ok { if localipfw, ok := e.firewall.(localIpUpdater); ok {
if err := localipfw.UpdateLocalIPs(); err != nil { if err := localipfw.UpdateLocalIPs(); err != nil {
log.Errorf("failed to update local IPs: %v", err) log.Errorf("failed to update local IPs: %v", err)
} }
} }
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
// This needs to be toggled before applying routes.
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
}
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
} }
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
@@ -965,7 +1008,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update routes: %v", err)
} }
if e.acl != nil { if e.acl != nil {
@@ -976,7 +1019,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules // Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil { forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err) log.Errorf("failed to update forward rules, err: %v", err)
} }
@@ -1022,14 +1066,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
protoDNSConfig := networkMap.GetDNSConfig() // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
if protoDNSConfig == nil { excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
protoDNSConfig = &mgmProto.DNSConfig{} e.connMgr.SetExcludeList(excludedLazyPeers)
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial e.networkSerial = serial
@@ -1065,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix.Masked(),
Domains: domain.FromPunycodeList(protoRoute.Domains), Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
@@ -1099,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
return entries return entries
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
@@ -1155,7 +1194,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
IP: strings.Join(offlinePeer.GetAllowedIps(), ","), IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
PubKey: offlinePeer.GetWgPubKey(), PubKey: offlinePeer.GetWgPubKey(),
FQDN: offlinePeer.GetFqdn(), FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected, ConnStatus: peer.StatusIdle,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
@@ -1191,12 +1230,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerIPs = append(peerIPs, allowedNetIP) peerIPs = append(peerIPs, allowedNetIP)
} }
conn, err := e.createPeerConn(peerKey, peerIPs) conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
if err != nil { if err != nil {
return fmt.Errorf("create peer connection: %w", err) return fmt.Errorf("create peer connection: %w", err)
} }
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close() conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey) return fmt.Errorf("peer already exists: %s", peerKey)
} }
@@ -1205,17 +1249,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
conn.AddBeforeAddPeerHook(e.beforePeerHook) conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook) conn.AddAfterRemovePeerHook(e.afterPeerHook)
} }
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
conn.Open()
return nil return nil
} }
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) { func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey) log.Debugf("creating peer connection %s", pubKey)
wgConfig := peer.WgConfig{ wgConfig := peer.WgConfig{
@@ -1229,11 +1266,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
// randomize connection timeout // randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{ config := peer.ConnConfig{
Key: pubKey, Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(), LocalKey: e.config.WgPrivateKey.PublicKey().String(),
Timeout: timeout, AgentVersion: agentVersion,
WgConfig: wgConfig, Timeout: timeout,
LocalWgPort: e.config.WgPort, WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
RosenpassConfig: peer.RosenpassConfig{ RosenpassConfig: peer.RosenpassConfig{
PubKey: e.getRosenpassPubKey(), PubKey: e.getRosenpassPubKey(),
Addr: e.getRosenpassAddr(), Addr: e.getRosenpassAddr(),
@@ -1249,7 +1287,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
}, },
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore) serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1270,7 +1317,7 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
conn, ok := e.peerStore.PeerConn(msg.Key) conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
if !ok { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)
} }
@@ -1406,6 +1453,7 @@ func (e *Engine) close() {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
} }
e.wgInterface = nil e.wgInterface = nil
e.statusRecorder.SetWgIface(nil)
} }
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
@@ -1578,13 +1626,39 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states. // and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool { func (e *Engine) RunHealthProbes() bool {
e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy() signalHealthy := e.signal.IsHealthy()
log.Debugf("signal health check: healthy=%t", signalHealthy) log.Debugf("signal health check: healthy=%t", signalHealthy)
managementHealthy := e.mgmClient.IsHealthy() managementHealthy := e.mgmClient.IsHealthy()
log.Debugf("management health check: healthy=%t", managementHealthy) log.Debugf("management health check: healthy=%t", managementHealthy)
results := append(e.probeSTUNs(), e.probeTURNs()...) stuns := slices.Clone(e.STUNs)
turns := slices.Clone(e.TURNs)
if e.wgInterface != nil {
stats, err := e.wgInterface.GetStats()
if err != nil {
log.Warnf("failed to get wireguard stats: %v", err)
e.syncMsgMux.Unlock()
return false
}
for _, key := range e.peerStore.PeersPubKey() {
// wgStats could be zero value, in which case we just reset the stats
wgStats, ok := stats[key]
if !ok {
continue
}
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
}
e.syncMsgMux.Unlock()
results := e.probeICE(stuns, turns)
e.statusRecorder.UpdateRelayStates(results) e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true relayHealthy := true
@@ -1596,37 +1670,16 @@ func (e *Engine) RunHealthProbes() bool {
} }
log.Debugf("relay health check: healthy=%t", relayHealthy) log.Debugf("relay health check: healthy=%t", relayHealthy)
for _, key := range e.peerStore.PeersPubKey() {
wgStats, err := e.wgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
continue
}
// wgStats could be zero value, in which case we just reset the stats
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
allHealthy := signalHealthy && managementHealthy && relayHealthy allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy) log.Debugf("all health checks completed: healthy=%t", allHealthy)
return allHealthy return allHealthy
} }
func (e *Engine) probeSTUNs() []relay.ProbeResult { func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
e.syncMsgMux.Lock() return append(
stuns := slices.Clone(e.STUNs) relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
e.syncMsgMux.Unlock() relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
)
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
}
func (e *Engine) probeTURNs() []relay.ProbeResult {
e.syncMsgMux.Lock()
turns := slices.Clone(e.TURNs)
e.syncMsgMux.Unlock()
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
} }
// restartEngine restarts the engine by cancelling the client context // restartEngine restarts the engine by cancelling the client context
@@ -1738,9 +1791,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
} }
// GetWgAddr returns the wireguard address // GetWgAddr returns the wireguard address
func (e *Engine) GetWgAddr() net.IP { func (e *Engine) GetWgAddr() netip.Addr {
if e.wgInterface == nil { if e.wgInterface == nil {
return nil return netip.Addr{}
} }
return e.wgInterface.Address().IP return e.wgInterface.Address().IP
} }
@@ -1750,6 +1803,10 @@ func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
) { ) {
if e.config.DisableServerRoutes {
return
}
if !enabled { if !enabled {
if e.dnsForwardMgr == nil { if e.dnsForwardMgr == nil {
return return
@@ -1805,29 +1862,24 @@ func (e *Engine) Address() (netip.Addr, error) {
return netip.Addr{}, errors.New("wireguard interface not initialized") return netip.Addr{}, errors.New("wireguard interface not initialized")
} }
addr := e.wgInterface.Address() return e.wgInterface.Address().IP, nil
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
}
return ip.Unmap(), nil
} }
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error { func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil { if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules") log.Warn("firewall is disabled, not updating forwarding rules")
return nil return nil, nil
} }
if len(rules) == 0 { if len(rules) == 0 {
if e.ingressGatewayMgr == nil { if e.ingressGatewayMgr == nil {
return nil return nil, nil
} }
err := e.ingressGatewayMgr.Close() err := e.ingressGatewayMgr.Close()
e.ingressGatewayMgr = nil e.ingressGatewayMgr = nil
e.statusRecorder.SetIngressGwMgr(nil) e.statusRecorder.SetIngressGwMgr(nil)
return err return nil, err
} }
if e.ingressGatewayMgr == nil { if e.ingressGatewayMgr == nil {
@@ -1878,7 +1930,35 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
log.Errorf("failed to update forwarding rules: %v", err) log.Errorf("failed to update forwarding rules: %v", err)
} }
return nberrors.FormatErrorOrNil(merr) return forwardingRules, nberrors.FormatErrorOrNil(merr)
}
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
excludedPeers := make(map[string]bool)
for _, r := range routes {
if r.Peer == "" {
continue
}
if !excludedPeers[r.Peer] {
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers[r.Peer] = true
}
}
for _, r := range rules {
ip := r.TranslatedAddress
for _, p := range peers {
for _, allowedIP := range p.GetAllowedIps() {
if allowedIP != ip.String() {
continue
}
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
excludedPeers[p.GetWgPubKey()] = true
}
}
}
return excludedPeers
} }
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.

View File

@@ -28,8 +28,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@@ -38,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@@ -53,6 +52,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@@ -93,12 +93,16 @@ type MockWGIface struct {
GetFilterFunc func() device.PacketFilter GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error) GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net GetNetFunc func() *netstack.Net
} }
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented")
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc() return m.GetInterfaceGUIDStringFunc()
} }
@@ -171,8 +175,8 @@ func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
return m.GetWGDeviceFunc() return m.GetWGDeviceFunc()
} }
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (m *MockWGIface) GetStats() (map[string]configurer.WGStats, error) {
return m.GetStatsFunc(peerKey) return m.GetStatsFunc()
} }
func (m *MockWGIface) GetProxy() wgproxy.Proxy { func (m *MockWGIface) GetProxy() wgproxy.Proxy {
@@ -371,13 +375,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return nil
},
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
@@ -400,6 +404,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
type testCase struct { type testCase struct {
name string name string
@@ -770,6 +776,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{} engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() { defer func() {
exitErr := engine.Stop() exitErr := engine.Stop()
@@ -966,6 +974,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
} }
engine.dnsServer = mockDNSServer engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() { defer func() {
exitErr := engine.Stop() exitErr := engine.Stop()
@@ -1476,7 +1486,7 @@ func getConnectedPeers(e *Engine) int {
i := 0 i := 0
for _, id := range e.peerStore.PeersPubKey() { for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id) conn, _ := e.peerStore.PeerConn(id)
if conn.Status() == peer.StatusConnected { if conn.IsConnected() {
i++ i++
} }
} }

View File

@@ -35,6 +35,7 @@ type wgIfaceBase interface {
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetStats(peerKey string) (configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
} }

View File

@@ -0,0 +1,9 @@
//go:build !linux || android
package activity
import "net"
var (
listenIP = net.IP{127, 0, 0, 1}
)

View File

@@ -0,0 +1,10 @@
//go:build !android
package activity
import "net"
var (
// use this ip to avoid eBPF proxy congestion
listenIP = net.IP{127, 0, 1, 1}
)

View File

@@ -0,0 +1,106 @@
package activity
import (
"fmt"
"net"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
wgIface lazyconn.WGIface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener
}
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
wgIface: wgIface,
peerCfg: cfg,
}
conn, err := d.newConn()
if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
}
d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
if err := d.createEndpoint(); err != nil {
return nil, err
}
d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
return d, nil
}
func (d *Listener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Debugf("exit from activity listener")
} else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
}
break
}
if n < 1 {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
break
}
if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
d.done.Unlock()
}
func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
if err := d.conn.Close(); err != nil {
d.peerCfg.Log.Errorf("failed to close UDP listener: %s", err)
}
d.done.Lock()
}
func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
}
func (d *Listener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{
Port: 0,
IP: listenIP,
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
log.Errorf("failed to create activity listener on %s: %s", addr, err)
return nil, err
}
return conn, nil
}

View File

@@ -0,0 +1,41 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -0,0 +1,95 @@
package activity
import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type Manager struct {
OnActivityChan chan peerid.ConnID
wgIface lazyconn.WGIface
peers map[peerid.ConnID]*Listener
done chan struct{}
mu sync.Mutex
}
func NewManager(wgIface lazyconn.WGIface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener),
done: make(chan struct{}),
}
return m
}
func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.peers[peerCfg.PeerConnID]; ok {
log.Warnf("activity listener already exists for: %s", peerCfg.PublicKey)
return nil
}
listener, err := NewListener(m.wgIface, peerCfg)
if err != nil {
return err
}
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
listener, ok := m.peers[peerConnID]
if !ok {
return
}
log.Debugf("removing activity listener")
delete(m.peers, peerConnID)
listener.Close()
}
func (m *Manager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
close(m.done)
for peerID, listener := range m.peers {
delete(m.peers, peerID)
listener.Close()
}
}
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
listener.ReadPackets()
m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok {
m.mu.Unlock()
return
}
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
}
}

View File

@@ -0,0 +1,162 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
PeerID string
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
type MocWGIface struct {
}
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
func TestManager_MonitorPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
}
case <-time.After(1 * time.Second):
}
}
func TestManager_RemovePeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
if err := trigger(addr); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case <-mgr.OnActivityChan:
t.Fatal("should not have active activity")
case <-time.After(1 * time.Second):
}
}
func TestManager_MultiPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
peer2 := &MocPeer{}
peerCfg2 := lazyconn.PeerConfig{
PublicKey: peer2.PeerID,
PeerConnID: peer2.ConnID(),
Log: log.WithField("peer", "examplePublicKey2"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := mgr.MonitorPeerActivity(peerCfg2); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
for i := 0; i < 2; i++ {
select {
case <-mgr.OnActivityChan:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
}
}
}
func trigger(addr string) error {
// Create a connection to the destination UDP address and port
conn, err := net.Dial("udp", addr)
if err != nil {
return err
}
defer conn.Close()
// Write the bytes to the UDP connection
_, err = conn.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,32 @@
/*
Package lazyconn provides mechanisms for managing lazy connections, which activate on demand to optimize resource usage and establish connections efficiently.
## Overview
The package includes a `Manager` component responsible for:
- Managing lazy connections activated on-demand
- Managing inactivity monitors for lazy connections (based on peer disconnection events)
- Maintaining a list of excluded peers that should always have permanent connections
- Handling remote peer connection initiatives based on peer signaling
## Thread-Safe Operations
The `Manager` ensures thread safety across multiple operations, categorized by caller:
- **Engine (single goroutine)**:
- `AddPeer`: Adds a peer to the connection manager.
- `RemovePeer`: Removes a peer from the connection manager.
- `ActivatePeer`: Activates a lazy connection for a peer. This come from Signal client
- `ExcludePeer`: Marks peers for a permanent connection. Like router peers and other peers that should always have a connection.
- **Connection Dispatcher (any peer routine)**:
- `onPeerConnected`: Suspend the inactivity monitor for an active peer connection.
- `onPeerDisconnected`: Starts the inactivity monitor for a disconnected peer.
- **Activity Manager**:
- `onPeerActivity`: Run peer.Open(context).
- **Inactivity Monitor**:
- `onPeerInactivityTimedOut`: Close peer connection and restart activity monitor.
*/
package lazyconn

View File

@@ -0,0 +1,26 @@
package lazyconn
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
)
func IsLazyConnEnabledByEnv() bool {
val := os.Getenv(EnvEnableLazyConn)
if val == "" {
return false
}
enabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
return false
}
return enabled
}

View File

@@ -0,0 +1,70 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}

View File

@@ -0,0 +1,156 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
trashHold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), trashHold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check to do not receive timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -0,0 +1,404 @@
package manager
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
const (
watcherActivity watcherType = iota
watcherInactivity
)
type watcherType int
type managedPeer struct {
peerCfg *lazyconn.PeerConfig
expectedWatcher watcherType
}
type Config struct {
InactivityThreshold *time.Duration
}
// Manager manages lazy connections
// It is responsible for:
// - Managing lazy connections activated on-demand
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
type Manager struct {
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
cancel context.CancelFunc
onInactive chan peerid.ConnID
}
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
onInactive: make(chan peerid.ConnID),
}
if config.InactivityThreshold != nil {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
m.inactivityThreshold = *config.InactivityThreshold
} else {
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
}
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
ctx, m.cancel = context.WithCancel(ctx)
for {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID)
}
}
}
// ExcludePeer marks peers for a permanent connection
// It removes peers from the managed list if they are added to the exclude list
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
// this case, we suppose that the connection status is connected or connecting.
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
added := make([]string, 0)
excludes := make(map[string]lazyconn.PeerConfig, len(peerConfigs))
for _, peerCfg := range peerConfigs {
log.Infof("update excluded lazy connection list with peer: %s", peerCfg.PublicKey)
excludes[peerCfg.PublicKey] = peerCfg
}
// if a peer is newly added to the exclude list, remove from the managed peers list
for pubKey, peerCfg := range excludes {
if _, wasExcluded := m.excludes[pubKey]; wasExcluded {
continue
}
added = append(added, pubKey)
peerCfg.Log.Infof("peer newly added to lazy connection exclude list")
m.removePeer(pubKey)
}
// if a peer has been removed from exclude list then it should be added to the managed peers
for pubKey, peerCfg := range m.excludes {
if _, stillExcluded := excludes[pubKey]; stillExcluded {
continue
}
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(ctx, peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue
}
}
m.excludes = excludes
return added
}
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
peerCfg.Log.Debugf("adding peer to lazy connection manager")
_, exists := m.excludes[peerCfg.PublicKey]
if exists {
return true, nil
}
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return false, nil
}
if err := m.activityManager.MonitorPeerActivity(peerCfg); err != nil {
return false, err
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherActivity,
}
return false, nil
}
// AddActivePeers adds a list of peers to the lazy connection manager
// suppose these peers was in connected or in connecting states
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
for _, cfg := range peerCfg {
if _, ok := m.managedPeers[cfg.PublicKey]; ok {
cfg.Log.Errorf("peer already managed")
continue
}
if err := m.addActivePeer(ctx, cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err
}
}
return nil
}
func (m *Manager) RemovePeer(peerID string) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.removePeer(peerID)
}
// ActivatePeer activates a peer connection when a signal message is received
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, ok := m.managedPeers[peerID]
if !ok {
return false
}
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
return false
}
// signal messages coming continuously after success activation, with this avoid the multiple activation
if mp.expectedWatcher == watcherInactivity {
return false
}
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
return false
}
mp.peerCfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
return true
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return nil
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherInactivity,
}
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
go im.Start(ctx, m.onInactive)
return nil
}
func (m *Manager) removePeer(peerID string) {
cfg, ok := m.managedPeers[peerID]
if !ok {
return
}
cfg.Log.Infof("removing lazy peer")
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID)
}
func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.cancel()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
log.Infof("lazy connection manager closed")
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherActivity {
mp.peerCfg.Log.Warnf("ignore activity event")
return
}
mp.peerCfg.Log.Infof("detected peer activity")
mp.expectedWatcher = watcherInactivity
mp.peerCfg.Log.Infof("starting inactivity monitor")
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
// just in case free up
m.inactivityMonitors[peerConnID].PauseTimer()
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

View File

@@ -0,0 +1,16 @@
package lazyconn
import (
"net/netip"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type PeerConfig struct {
PublicKey string
AllowedIPs []netip.Prefix
PeerConnID id.ConnID
Log *log.Entry
}

View File

@@ -0,0 +1,41 @@
package lazyconn
import (
"strings"
"github.com/hashicorp/go-version"
)
var (
minVersion = version.Must(version.NewVersion("0.45.0"))
)
func IsSupported(agentVersion string) bool {
if agentVersion == "development" {
return true
}
// filter out versions like this: a6c5960, a7d5c522, d47be154
if !strings.Contains(agentVersion, ".") {
return false
}
normalizedVersion := normalizeVersion(agentVersion)
inputVer, err := version.NewVersion(normalizedVersion)
if err != nil {
return false
}
return inputVer.GreaterThanOrEqual(minVersion)
}
func normalizeVersion(version string) string {
// Remove prefixes like 'v' or 'a'
if len(version) > 0 && (version[0] == 'v' || version[0] == 'a') {
version = version[1:]
}
// Remove any suffixes like '-dirty', '-dev', '-SNAPSHOT', etc.
parts := strings.Split(version, "-")
return parts[0]
}

View File

@@ -0,0 +1,31 @@
package lazyconn
import "testing"
func TestIsSupported(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"development", true},
{"0.45.0", true},
{"v0.45.0", true},
{"0.45.1", true},
{"0.45.1-SNAPSHOT-559e6731", true},
{"v0.45.1-dev", true},
{"a7d5c522", false},
{"0.9.6", false},
{"0.9.6-SNAPSHOT", false},
{"0.9.6-SNAPSHOT-2033650", false},
{"meta_wt_version", false},
{"v0.31.1-dev", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
if got := IsSupported(tt.version); got != tt.want {
t.Errorf("IsSupported() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,14 @@
package lazyconn
import (
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}

View File

@@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
// fallback if mark rules are not in place // fallback if mark rules are not in place
wgnet := c.iface.Address().Network wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice()) return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
} }
// mapRxPackets maps packet counts to RX based on flow direction // mapRxPackets maps packet counts to RX based on flow direction
@@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
// fallback if marks are not set // fallback if marks are not set
wgaddr := c.iface.Address().IP wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch { switch {
case wgaddr.Equal(src): case wgaddr == srcIP:
return nftypes.Egress return nftypes.Egress
case wgaddr.Equal(dst): case wgaddr == dstIP:
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(src): case wgnetwork.Contains(srcIP):
// netbird network -> resource network // netbird network -> resource network
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(dst): case wgnetwork.Contains(dstIP):
// resource network -> netbird network // resource network -> netbird network
return nftypes.Egress return nftypes.Egress
} }

View File

@@ -2,7 +2,7 @@ package logger
import ( import (
"context" "context"
"net" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -23,17 +23,16 @@ type Logger struct {
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgIfaceIPNet net.IPNet wgIfaceNet netip.Prefix
dnsCollection atomic.Bool dnsCollection atomic.Bool
exitNodeCollection atomic.Bool exitNodeCollection atomic.Bool
Store types.Store Store types.Store
} }
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
return &Logger{ return &Logger{
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet, wgIfaceNet: wgIfaceIPNet,
Store: store.NewMemoryStore(), Store: store.NewMemoryStore(),
} }
} }
@@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
var isSrcExitNode bool var isSrcExitNode bool
var isDestExitNode bool var isDestExitNode bool
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.SourceIP) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
} }
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.DestIP) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
} }

View File

@@ -1,7 +1,7 @@
package logger_test package logger_test
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -12,7 +12,7 @@ import (
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
logger := logger.New(nil, net.IPNet{}) logger := logger.New(nil, netip.Prefix{})
logger.Enable() logger.Enable()
event := types.EventFields{ event := types.EventFields{

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net" "net/netip"
"runtime" "runtime"
"sync" "sync"
"time" "time"
@@ -34,11 +34,11 @@ type Manager struct {
// NewManager creates a new netflow manager // NewManager creates a new netflow manager
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var ipNet net.IPNet var prefix netip.Prefix
if iface != nil { if iface != nil {
ipNet = *iface.Address().Network prefix = iface.Address().Network
} }
flowLogger := logger.New(statusRecorder, ipNet) flowLogger := logger.New(statusRecorder, prefix)
var ct nftypes.ConnTracker var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@@ -123,8 +123,14 @@ func (m *Manager) disableFlow() error {
m.logger.Close() m.logger.Close()
if m.receiverClient != nil { if m.receiverClient == nil {
return m.receiverClient.Close() return nil
}
err := m.receiverClient.Close()
m.receiverClient = nil
if err != nil {
return fmt.Errorf("close: %w", err)
} }
return nil return nil

View File

@@ -1,7 +1,7 @@
package netflow package netflow
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
func TestManager_Update(t *testing.T) { func TestManager_Update(t *testing.T) {
mockIFace := &mockIFaceMapper{ mockIFace := &mockIFaceMapper{
address: wgaddr.Address{ address: wgaddr.Address{
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.1/32"),
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
}, },
isUserspaceBind: true, isUserspaceBind: true,
} }
@@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
func TestManager_Update_TokenPreservation(t *testing.T) { func TestManager_Update_TokenPreservation(t *testing.T) {
mockIFace := &mockIFaceMapper{ mockIFace := &mockIFaceMapper{
address: wgaddr.Address{ address: wgaddr.Address{
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.1/32"),
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
}, },
isUserspaceBind: true, isUserspaceBind: true,
} }

View File

@@ -17,8 +17,12 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -26,32 +30,20 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case connPriorityNone:
return "None"
case connPriorityRelay:
return "PriorityRelay"
case connPriorityICETurn:
return "PriorityICETurn"
case connPriorityICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}
const ( const (
defaultWgKeepAlive = 25 * time.Second defaultWgKeepAlive = 25 * time.Second
connPriorityNone ConnPriority = 0
connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 2
connPriorityICEP2P ConnPriority = 3
) )
type ServiceDependencies struct {
StatusRecorder *Status
Signaler *Signaler
IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher
}
type WgConfig struct { type WgConfig struct {
WgListenPort int WgListenPort int
RemoteKey string RemoteKey string
@@ -76,6 +68,8 @@ type ConnConfig struct {
// LocalKey is a public key of a local peer // LocalKey is a public key of a local peer
LocalKey string LocalKey string
AgentVersion string
Timeout time.Duration Timeout time.Duration
WgConfig WgConfig WgConfig WgConfig
@@ -89,22 +83,23 @@ type ConnConfig struct {
} }
type Conn struct { type Conn struct {
log *log.Entry Log *log.Entry
mu sync.Mutex mu sync.Mutex
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
config ConnConfig config ConnConfig
statusRecorder *Status statusRecorder *Status
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager relayManager *relayClient.Manager
handshaker *Handshaker srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string) onDisconnected func(remotePeer string)
statusRelay *AtomicConnStatus statusRelay *worker.AtomicWorkerStatus
statusICE *AtomicConnStatus statusICE *worker.AtomicWorkerStatus
currentConnPriority ConnPriority currentConnPriority conntype.ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection opened bool // this flag is used to prevent close in case of not opened connection
workerICE *WorkerICE workerICE *WorkerICE
@@ -120,9 +115,12 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup semaphore *semaphoregroup.SemaphoreGroup
peerConnDispatcher *dispatcher.ConnectionDispatcher
wg sync.WaitGroup
// debug purpose // debug purpose
dumpState *stateDump dumpState *stateDump
@@ -130,91 +128,101 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
if len(config.WgConfig.AllowedIps) == 0 { if len(config.WgConfig.AllowedIps) == 0 {
return nil, fmt.Errorf("allowed IPs is empty") return nil, fmt.Errorf("allowed IPs is empty")
} }
ctx, ctxCancel := context.WithCancel(engineCtx)
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
var conn = &Conn{ var conn = &Conn{
log: connLog, Log: connLog,
ctx: ctx, config: config,
ctxCancel: ctxCancel, statusRecorder: services.StatusRecorder,
config: config, signaler: services.Signaler,
statusRecorder: statusRecorder, iFaceDiscover: services.IFaceDiscover,
signaler: signaler, relayManager: services.RelayManager,
relayManager: relayManager, srWatcher: services.SrWatcher,
statusRelay: NewAtomicConnStatus(), semaphore: services.Semaphore,
statusICE: NewAtomicConnStatus(), peerConnDispatcher: services.PeerConnDispatcher,
semaphore: semaphore, statusRelay: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, statusRecorder), statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
} }
ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen()
go conn.dumpState.Start(ctx)
return conn, nil return conn, nil
} }
// Open opens connection to the remote peer // Open opens connection to the remote peer
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used. // be used.
func (conn *Conn) Open() { func (conn *Conn) Open(engineCtx context.Context) error {
conn.semaphore.Add(conn.ctx) conn.semaphore.Add(engineCtx)
conn.log.Debugf("open connection to peer")
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.opened = true
if conn.opened {
conn.semaphore.Done(engineCtx)
return nil
}
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.handshaker.Listen(conn.ctx)
}()
go conn.dumpState.Start(conn.ctx)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
ConnStatus: StatusDisconnected, ConnStatus: StatusConnecting,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
err := conn.statusRecorder.UpdatePeerState(peerState) if err := conn.statusRecorder.UpdatePeerState(peerState); err != nil {
if err != nil { conn.Log.Warnf("error while updating the state err: %v", err)
conn.log.Warnf("error while updating the state err: %v", err)
} }
go conn.startHandshakeAndReconnect(conn.ctx) conn.wg.Add(1)
} go func() {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx)
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { conn.dumpState.SendOffer()
defer conn.semaphore.Done(conn.ctx) if err := conn.handshaker.sendOffer(); err != nil {
conn.waitInitialRandomSleepTime(ctx) conn.Log.Errorf("failed to send initial offer: %v", err)
}
conn.dumpState.SendOffer() conn.wg.Add(1)
err := conn.handshaker.sendOffer() go func() {
if err != nil { conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.log.Errorf("failed to send initial offer: %v", err) conn.wg.Done()
} }()
}()
go conn.guard.Start(ctx) conn.opened = true
go conn.listenGuardEvent(ctx) return nil
} }
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
@@ -223,14 +231,14 @@ func (conn *Conn) Close() {
defer conn.wgWatcherWg.Wait() defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.log.Infof("close peer connection")
conn.ctxCancel()
if !conn.opened { if !conn.opened {
conn.log.Debugf("ignore close connection to peer") conn.Log.Debugf("ignore close connection to peer")
return return
} }
conn.Log.Infof("close peer connection")
conn.ctxCancel()
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
conn.workerRelay.CloseConn() conn.workerRelay.CloseConn()
conn.workerICE.Close() conn.workerICE.Close()
@@ -238,7 +246,7 @@ func (conn *Conn) Close() {
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn() err := conn.wgProxyRelay.CloseConn()
if err != nil { if err != nil {
conn.log.Errorf("failed to close wg proxy for relay: %v", err) conn.Log.Errorf("failed to close wg proxy for relay: %v", err)
} }
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
@@ -246,13 +254,13 @@ func (conn *Conn) Close() {
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
err := conn.wgProxyICE.CloseConn() err := conn.wgProxyICE.CloseConn()
if err != nil { if err != nil {
conn.log.Errorf("failed to close wg proxy for ice: %v", err) conn.Log.Errorf("failed to close wg proxy for ice: %v", err)
} }
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
if err := conn.removeWgPeer(); err != nil { if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID() conn.freeUpConnID()
@@ -262,14 +270,16 @@ func (conn *Conn) Close() {
} }
conn.setStatusToDisconnected() conn.setStatusToDisconnected()
conn.log.Infof("peer connection has been closed") conn.opened = false
conn.wg.Wait()
conn.Log.Infof("peer connection closed")
} }
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
conn.dumpState.RemoteAnswer() conn.dumpState.RemoteAnswer()
conn.log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) return conn.handshaker.OnRemoteAnswer(answer)
} }
@@ -298,7 +308,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
conn.dumpState.RemoteOffer() conn.dumpState.RemoteOffer()
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) return conn.handshaker.OnRemoteOffer(offer)
} }
@@ -307,19 +317,24 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig return conn.config.WgConfig
} }
// Status returns current status of the Conn // IsConnected unit tests only
func (conn *Conn) Status() ConnStatus { // refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
return conn.evalStatus() return conn.currentConnPriority != conntype.None
} }
func (conn *Conn) GetKey() string { func (conn *Conn) GetKey() string {
return conn.config.Key return conn.config.Key
} }
func (conn *Conn) ConnID() id.ConnID {
return id.ConnID(conn)
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -327,21 +342,21 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return return
} }
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) { if remoteConnNil(conn.Log, iceConnInfo.RemoteConn) {
conn.log.Errorf("remote ICE connection is nil") conn.Log.Errorf("remote ICE connection is nil")
return return
} }
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade // this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
// todo consider to remove this check // todo consider to remove this check
if conn.currentConnPriority > priority { if conn.currentConnPriority > priority {
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority) conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
conn.statusICE.Set(StatusConnected) conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
return return
} }
conn.log.Infof("set ICE to active connection") conn.Log.Infof("set ICE to active connection")
conn.dumpState.P2PConnected() conn.dumpState.P2PConnected()
var ( var (
@@ -353,7 +368,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return return
} }
ep = wgProxy.EndpointAddr() ep = wgProxy.EndpointAddr()
@@ -369,7 +384,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err) conn.Log.Errorf("Before add peer hook failed: %v", err)
} }
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
@@ -388,10 +403,16 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
oldState := conn.currentConnPriority
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected) conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
if oldState == conntype.None {
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
} }
func (conn *Conn) onICEStateDisconnected() { func (conn *Conn) onICEStateDisconnected() {
@@ -402,22 +423,22 @@ func (conn *Conn) onICEStateDisconnected() {
return return
} }
conn.log.Tracef("ICE connection state changed to disconnected") conn.Log.Tracef("ICE connection state changed to disconnected")
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil { if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
} }
} }
// switch back to relay connection // switch back to relay connection
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.log.Infof("ICE disconnected, set Relay to active connection") conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay() conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.Log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.wgWatcherWg.Add(1) conn.wgWatcherWg.Add(1)
@@ -425,17 +446,18 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done() defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
}() }()
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = conntype.Relay
} else { } else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = connPriorityNone conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
changed := conn.statusICE.Get() != StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
if changed { if changed {
conn.guard.SetICEConnDisconnected() conn.guard.SetICEConnDisconnected()
} }
conn.statusICE.Set(StatusDisconnected) conn.statusICE.SetDisconnected()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -446,7 +468,7 @@ func (conn *Conn) onICEStateDisconnected() {
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState) err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err) conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
} }
} }
@@ -456,41 +478,41 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil { if err := rci.relayedConn.Close(); err != nil {
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) conn.Log.Warnf("failed to close unnecessary relayed connection: %v", err)
} }
return return
} }
conn.dumpState.RelayConnected() conn.dumpState.RelayConnected()
conn.log.Debugf("Relay connection has been established, setup the WireGuard") conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy, err := conn.newProxy(rci.relayedConn) wgProxy, err := conn.newProxy(rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.isICEActive() { if conn.isICEActive() {
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected) conn.statusRelay.SetConnected()
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err) conn.Log.Errorf("Before add peer hook failed: %v", err)
} }
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err) conn.Log.Warnf("Failed to close relay connection: %v", err)
} }
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
@@ -502,12 +524,13 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround() wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = conntype.Relay
conn.statusRelay.Set(StatusConnected) conn.statusRelay.SetConnected()
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay") conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
} }
func (conn *Conn) onRelayDisconnected() { func (conn *Conn) onRelayDisconnected() {
@@ -518,14 +541,15 @@ func (conn *Conn) onRelayDisconnected() {
return return
} }
conn.log.Infof("relay connection is disconnected") conn.Log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay { if conn.currentConnPriority == conntype.Relay {
conn.log.Infof("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil { if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.currentConnPriority = connPriorityNone conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@@ -533,11 +557,11 @@ func (conn *Conn) onRelayDisconnected() {
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != worker.StatusDisconnected
if changed { if changed {
conn.guard.SetRelayedConnDisconnected() conn.guard.SetRelayedConnDisconnected()
} }
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.SetDisconnected()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -546,22 +570,15 @@ func (conn *Conn) onRelayDisconnected() {
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
} }
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) conn.Log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
} }
} }
func (conn *Conn) listenGuardEvent(ctx context.Context) { func (conn *Conn) onGuardEvent() {
for { conn.Log.Debugf("send offer to peer")
select { conn.dumpState.SendOffer()
case <-conn.guard.Reconnect: if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Infof("send offer to peer") conn.Log.Errorf("failed to send offer: %v", err)
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err)
}
case <-ctx.Done():
return
}
} }
} }
@@ -588,7 +605,7 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
err := conn.statusRecorder.UpdatePeerRelayedState(peerState) err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to save peer's Relay state, got error: %v", err) conn.Log.Warnf("unable to save peer's Relay state, got error: %v", err)
} }
} }
@@ -607,17 +624,18 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
err := conn.statusRecorder.UpdatePeerICEState(peerState) err := conn.statusRecorder.UpdatePeerICEState(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to save peer's ICE state, got error: %v", err) conn.Log.Warnf("unable to save peer's ICE state, got error: %v", err)
} }
} }
func (conn *Conn) setStatusToDisconnected() { func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.SetDisconnected()
conn.statusICE.Set(StatusDisconnected) conn.statusICE.SetDisconnected()
conn.currentConnPriority = conntype.None
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: StatusDisconnected, ConnStatus: StatusIdle,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
@@ -625,10 +643,10 @@ func (conn *Conn) setStatusToDisconnected() {
if err != nil { if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available. // pretty common error because by that time Engine can already remove the peer and status won't be available.
// todo rethink status updates // todo rethink status updates
conn.log.Debugf("error while updating peer's state, err: %v", err) conn.Log.Debugf("error while updating peer's state, err: %v", err)
} }
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil { if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err) conn.Log.Debugf("failed to reset wireguard stats for peer: %s", err)
} }
} }
@@ -656,32 +674,24 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
} }
func (conn *Conn) isRelayed() bool { func (conn *Conn) isRelayed() bool {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) { switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn:
return true
default:
return false return false
} }
if conn.currentConnPriority == connPriorityICEP2P {
return false
}
return true
} }
func (conn *Conn) evalStatus() ConnStatus { func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected { if conn.statusRelay.Get() == worker.StatusConnected || conn.statusICE.Get() == worker.StatusConnected {
return StatusConnected return StatusConnected
} }
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting { return StatusConnecting
return StatusConnecting
}
return StatusDisconnected
} }
func (conn *Conn) isConnectedOnAllWay() (connected bool) { func (conn *Conn) isConnectedOnAllWay() (connected bool) {
conn.mu.Lock() // would be better to protect this with a mutex, but it could cause deadlock with Close function
defer conn.mu.Unlock()
defer func() { defer func() {
if !connected { if !connected {
@@ -689,12 +699,12 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
} }
}() }()
if conn.statusICE.Get() == StatusDisconnected { if conn.statusICE.Get() == worker.StatusDisconnected {
return false return false
} }
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() != StatusConnected { if conn.statusRelay.Get() == worker.StatusDisconnected {
return false return false
} }
} }
@@ -716,7 +726,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" { if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil { if err := hook(conn.connIDRelay); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err) conn.Log.Errorf("After remove peer hook failed: %v", err)
} }
} }
conn.connIDRelay = "" conn.connIDRelay = ""
@@ -725,7 +735,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDICE != "" { if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil { if err := hook(conn.connIDICE); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err) conn.Log.Errorf("After remove peer hook failed: %v", err)
} }
} }
conn.connIDICE = "" conn.connIDICE = ""
@@ -733,7 +743,7 @@ func (conn *Conn) freeUpConnID() {
} }
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(), IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
Port: conn.config.WgConfig.WgListenPort, Port: conn.config.WgConfig.WgListenPort,
@@ -741,18 +751,18 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
wgProxy := conn.config.WgConfig.WgInterface.GetProxy() wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil { if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return nil, err return nil, err
} }
return wgProxy, nil return wgProxy, nil
} }
func (conn *Conn) isReadyToUpgrade() bool { func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
} }
func (conn *Conn) isICEActive() bool { func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
} }
func (conn *Conn) removeWgPeer() error { func (conn *Conn) removeWgPeer() error {
@@ -760,10 +770,10 @@ func (conn *Conn) removeWgPeer() error {
} }
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err) conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil { if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil { if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr) conn.Log.Warnf("Failed to close wg proxy: %v", ierr)
} }
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@@ -773,16 +783,16 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
func (conn *Conn) logTraceConnState() { func (conn *Conn) logTraceConnState() {
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) conn.Log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} else { } else {
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE) conn.Log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
} }
} }
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
} }
} }
conn.wgProxyRelay = proxy conn.wgProxyRelay = proxy
@@ -793,6 +803,10 @@ func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr() return conn.config.WgConfig.AllowedIps[0].Addr()
} }
func (conn *Conn) AgentVersionString() string {
return conn.config.AgentVersion
}
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key { func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
if conn.config.RosenpassConfig.PubKey == nil { if conn.config.RosenpassConfig.PubKey == nil {
return conn.config.WgConfig.PreSharedKey return conn.config.WgConfig.PreSharedKey
@@ -804,7 +818,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
determKey, err := conn.rosenpassDetermKey() determKey, err := conn.rosenpassDetermKey()
if err != nil { if err != nil {
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err) conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey return conn.config.WgConfig.PreSharedKey
} }

View File

@@ -1,58 +1,29 @@
package peer package peer
import ( import (
"sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
// StatusConnected indicate the peer is in connected state // StatusIdle indicate the peer is in disconnected state
StatusConnected ConnStatus = iota StatusIdle ConnStatus = iota
// StatusConnecting indicate the peer is in connecting state // StatusConnecting indicate the peer is in connecting state
StatusConnecting StatusConnecting
// StatusDisconnected indicate the peer is in disconnected state // StatusConnected indicate the peer is in connected state
StatusDisconnected StatusConnected
) )
// ConnStatus describe the status of a peer's connection // ConnStatus describe the status of a peer's connection
type ConnStatus int32 type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string { func (s ConnStatus) String() string {
switch s { switch s {
case StatusConnecting: case StatusConnecting:
return "Connecting" return "Connecting"
case StatusConnected: case StatusConnected:
return "Connected" return "Connected"
case StatusDisconnected: case StatusIdle:
return "Disconnected" return "Idle"
default: default:
log.Errorf("unknown status: %d", s) log.Errorf("unknown status: %d", s)
return "INVALID_PEER_CONNECTION_STATUS" return "INVALID_PEER_CONNECTION_STATUS"

View File

@@ -14,7 +14,7 @@ func TestConnStatus_String(t *testing.T) {
want string want string
}{ }{
{"StatusConnected", StatusConnected, "Connected"}, {"StatusConnected", StatusConnected, "Connected"},
{"StatusDisconnected", StatusDisconnected, "Disconnected"}, {"StatusIdle", StatusIdle, "Idle"},
{"StatusConnecting", StatusConnecting, "Connecting"}, {"StatusConnecting", StatusConnecting, "Connecting"},
} }
@@ -24,5 +24,4 @@ func TestConnStatus_String(t *testing.T) {
assert.Equal(t, got, table.want, "they should be equal") assert.Equal(t, got, table.want, "they should be equal")
}) })
} }
} }

View File

@@ -1,7 +1,6 @@
package peer package peer
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@@ -11,6 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -18,6 +18,8 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
var testDispatcher = dispatcher.NewConnectionDispatcher()
var connConf = ConnConfig{ var connConf = ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
@@ -48,7 +50,13 @@ func TestNewConn_interfaceFilter(t *testing.T) {
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
sd := ServiceDependencies{
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@@ -60,7 +68,13 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@@ -94,7 +108,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@@ -125,43 +145,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
tables := []struct {
name string
statusIce ConnStatus
statusRelay ConnStatus
want ConnStatus
}{
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
func TestConn_presharedKey(t *testing.T) { func TestConn_presharedKey(t *testing.T) {
conn1 := Conn{ conn1 := Conn{

View File

@@ -0,0 +1,29 @@
package conntype
import (
"fmt"
)
const (
None ConnPriority = 0
Relay ConnPriority = 1
ICETurn ConnPriority = 2
ICEP2P ConnPriority = 3
)
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case None:
return "None"
case Relay:
return "PriorityRelay"
case ICETurn:
return "PriorityICETurn"
case ICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}

View File

@@ -0,0 +1,52 @@
package dispatcher
import (
"sync"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type ConnectionListener struct {
OnConnected func(peerID id.ConnID)
OnDisconnected func(peerID id.ConnID)
}
type ConnectionDispatcher struct {
listeners map[*ConnectionListener]struct{}
mu sync.Mutex
}
func NewConnectionDispatcher() *ConnectionDispatcher {
return &ConnectionDispatcher{
listeners: make(map[*ConnectionListener]struct{}),
}
}
func (e *ConnectionDispatcher) AddListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
e.listeners[listener] = struct{}{}
}
func (e *ConnectionDispatcher) RemoveListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
delete(e.listeners, listener)
}
func (e *ConnectionDispatcher) NotifyConnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnConnected(peerConnID)
}
}
func (e *ConnectionDispatcher) NotifyDisconnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnDisconnected(peerConnID)
}
}

Some files were not shown because too many files have changed in this diff Show More