Compare commits

...

38 Commits

Author SHA1 Message Date
Zoltan Papp
802a18167c [client] Do not reconnect to mgm server in case of handler error (#3856)
* Do not reconnect to mgm server in case of handler error
Set to nil the flow grpc client to nil

* Better error handling
2025-05-21 20:18:21 +02:00
hakansa
e9108ffe6c [client] Add latest gzipped rotated log file to the debug bundle (#3848)
[client] Add latest gzipped rotated log file to the debug bundle
2025-05-21 17:50:54 +03:00
Viktor Liu
e806d9de38 [client] Fix legacy routes when connecting to management servers older than v0.30.0 (#3854) 2025-05-21 13:48:55 +02:00
Zoltan Papp
daa8380df9 [client] Feature/lazy connection (#3379)
With the lazy connection feature, the peer will connect to target peers on-demand. The trigger can be any IP traffic.

This feature can be enabled with the NB_ENABLE_EXPERIMENTAL_LAZY_CONN environment variable.

When the engine receives a network map, it binds a free UDP port for every remote peer, and the system configures WireGuard endpoints for these ports. When traffic appears on a UDP socket, the system removes this listener and starts the peer connection procedure immediately.

Key changes
Fix slow netbird status -d command
Move from engine.go file to conn_mgr.go the peer connection related code
Refactor the iface interface usage and moved interface file next to the engine code
Add new command line flag and UI option to enable feature
The peer.Conn struct is reusable after it has been closed.
Change connection states
Connection states
Idle: The peer is not attempting to establish a connection. This typically means it's in a lazy state or the remote peer is expired.

Connecting: The peer is actively trying to establish a connection. This occurs when the peer has entered an active state and is continuously attempting to reach the remote peer.

Connected: A successful peer-to-peer connection has been established and communication is active.
2025-05-21 11:12:28 +02:00
Bethuel Mmbaga
4785f23fc4 [management] Migrate events sqlite store to gorm (#3837) 2025-05-20 17:00:37 +03:00
Viktor Liu
1d4cfb83e7 [client] Fix UI new version notifier (#3845) 2025-05-20 10:39:17 +02:00
Pascal Fischer
207fa059d2 [management] make locking strength clause optional (#3844) 2025-05-19 16:42:47 +02:00
Viktor Liu
cbcdad7814 [misc] Update issue template (#3842) 2025-05-19 15:41:24 +02:00
Pascal Fischer
701c13807a [management] add flag to disable auto-migration (#3840) 2025-05-19 13:36:24 +02:00
Viktor Liu
99f8dc7748 [client] Offer to remove netbird data in windows uninstall (#3766) 2025-05-16 17:39:30 +02:00
Pascal Fischer
f1de8e6eb0 [management] Make startup period configurable (#3767) 2025-05-16 13:16:51 +02:00
Viktor Liu
b2a10780af [client] Disable dnssec for systemd explicitly (#3831) 2025-05-16 09:43:13 +02:00
Pascal Fischer
43ae79d848 [management] extend rest client lib (#3830) 2025-05-15 18:20:29 +02:00
Pascal Fischer
e520b64c6d [signal] remove stream receive server side (#3820) 2025-05-14 19:28:51 +02:00
hakansa
92c91bbdd8 [client] Add FreeBSD desktop client support to OAuth flow (#3822)
[client] Add FreeBSD desktop client support to OAuth flow
2025-05-14 19:52:02 +03:00
Vlad
adf494e1ac [management] fix a bug with missed extra dns labels for a new peer (#3798) 2025-05-14 17:50:21 +02:00
Vlad
2158461121 [management,client] PKCE add flag parameter prompt=login or max_age (#3824) 2025-05-14 17:48:51 +02:00
Bethuel Mmbaga
0cd4b601c3 [management] Add connection type filter to Network Traffic API (#3815) 2025-05-14 11:15:50 +03:00
Zoltan Papp
ee1cec47b3 [client, android] Do not propagate empty routes (#3805)
If we get domain routes the Network prefix variable in route structure will be invalid (engine.go:1057). When we handower to Android the routes, we must to filter out the domain routes. If we do not do it the Android code will get "invalid prefix" string as a route.
2025-05-13 15:21:06 +02:00
Pascal Fischer
efb0edfc4c [signal] adjust signal log levels 2 (#3817) 2025-05-12 23:52:29 +02:00
Pascal Fischer
20f59ddecb [signal] adjust log levels (#3813) 2025-05-12 19:48:47 +02:00
hakansa
2f34e984b0 [client] Add TCP support to DNS forwarder service listener (#3790)
[client] Add TCP support to DNS forwarder service listener
2025-05-09 15:06:34 +03:00
Viktor Liu
d5b52e86b6 [client] Ignore irrelevant route changes to tracked network monitor routes (#3796) 2025-05-09 14:01:21 +02:00
Zoltan Papp
cad2fe1f39 Return with the correct copied length (#3804) 2025-05-09 13:56:27 +02:00
Pascal Fischer
fcd2c15a37 [management] policy delete cleans policy rules (#3788) 2025-05-07 07:25:25 +02:00
Bethuel Mmbaga
ebda0fc538 [management] Delete service users with account manager (#3793) 2025-05-06 17:31:03 +02:00
M. Essam
ac135ab11d [management/client/rest] fix panic when body is nil (#3714)
Fixes panic occurring when body is nil (this usually happens when connections is refused) due to lack of nil check by centralizing response.Body.Close() behavior.
2025-05-05 18:54:47 +02:00
Pascal Fischer
25faf9283d [management] removal of foreign key constraint enforcement on sqlite (#3786) 2025-05-05 18:21:48 +02:00
hakansa
59faaa99f6 [client] Improve NetBird installation script to handle daemon connection timeout (#3761)
[client] Improve NetBird installation script to handle daemon connection timeout
2025-05-05 17:05:01 +03:00
Viktor Liu
9762b39f29 [client] Fix stale local records (#3776) 2025-05-05 14:29:05 +02:00
Alin Trăistaru
ffdd115ded [client] set TLS ServerName for hostname-based QUIC connections (#3673)
* fix: set TLS ServerName for hostname-based QUIC connections

When connecting to a relay server by hostname, certificates are
validated against the IP address instead of the hostname.
This change sets ServerName in the TLS config when connecting
via hostname, ensuring proper certificate validation.

* use default port if port is missing in URL string
2025-05-05 12:20:54 +02:00
Pascal Fischer
055df9854c [management] add gorm tag for primary key for the networks objects (#3758) 2025-05-04 20:58:04 +02:00
Maycon Santos
12f883badf [management] Optimize load account (#3774) 2025-05-02 00:59:41 +02:00
Maycon Santos
2abb92b0d4 [management] Get account id with order (#3773)
updated log to display account id
2025-05-02 00:25:46 +02:00
Viktor Liu
01c3719c5d [client] Add debug for duration option to netbird ui (#3772) 2025-05-01 23:25:27 +02:00
Pedro Maia Costa
7b64953eed [management] user info with role permissions (#3728) 2025-05-01 11:24:55 +01:00
Viktor Liu
9bc7d788f0 [client] Add debug upload option to netbird ui (#3768) 2025-05-01 00:48:31 +02:00
Pedro Maia Costa
b5419ef11a [management] limit peers based on module read permission (#3757) 2025-04-30 15:53:18 +01:00
165 changed files with 7857 additions and 3323 deletions

View File

@@ -37,16 +37,21 @@ If yes, which one?
**Debug output** **Debug output**
To help us resolve the problem, please attach the following debug output To help us resolve the problem, please attach the following anonymized status output
netbird status -dA netbird status -dA
As well as the file created by Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?** **Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions - [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones) - [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client - [ ] Restarted the NetBird client
- [ ] Disabled other VPN software - [ ] Disabled other VPN software
- [ ] Checked firewall settings - [ ] Checked firewall settings

View File

@@ -179,6 +179,7 @@ jobs:
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445" grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"' grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true' grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy

View File

@@ -235,13 +235,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"runtime"
"strings" "strings"
"time" "time"
@@ -98,11 +99,11 @@ var loginCmd = &cobra.Command{
} }
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq, DnsLabels: dnsLabelsReq,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop()) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
} }
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
} }

View File

@@ -26,23 +26,23 @@ import (
) )
const ( const (
externalIPMapFlag = "external-ip-map" externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address" dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass" enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive" rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key" preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name" interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port" wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor" networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access" blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle" enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundleURL = "upload-bundle-url" uploadBundle = "upload-bundle"
defaultBundleURL = "https://upload.debug.netbird.io" + types.GetURLPath uploadBundleURL = "upload-bundle-url"
) )
var ( var (
@@ -81,6 +81,7 @@ var (
blockLANAccess bool blockLANAccess bool
debugUploadBundle bool debugUploadBundle bool
debugUploadBundleURL string debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
@@ -185,10 +186,11 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, defaultBundleURL, "Service URL to get an URL to upload the debug bundle") debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -44,7 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected": case "", "idle", "connecting", "connected":
if strings.ToLower(statusFilter) != "" { if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
default: default:
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter) return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
} }
if len(ipsFilter) > 0 { if len(ipsFilter) > 0 {

View File

@@ -194,6 +194,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.BlockLANAccess = &blockLANAccess ic.BlockLANAccess = &blockLANAccess
} }
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
if err != nil { if err != nil {
return err return err
@@ -262,17 +266,17 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
} }
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels, DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0, CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -332,6 +336,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.BlockLanAccess = &blockLANAccess loginRequest.BlockLanAccess = &blockLANAccess
} }
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse

View File

@@ -201,14 +201,30 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() { func (c *KernelConfigurer) Close() {
} }
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) { func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey) stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err) return nil, fmt.Errorf("wgctl: %w", err)
} }
return WGStats{ defer func() {
LastHandshake: peer.LastHandshakeTime, err = wg.Close()
TxBytes: peer.TransmitBytes, if err != nil {
RxBytes: peer.ReceiveBytes, log.Errorf("Got error while closing wgctl: %v", err)
}, nil }
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}
}
return stats, nil
} }

View File

@@ -1,6 +1,7 @@
package configurer package configurer
import ( import (
"encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
@@ -17,6 +18,13 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct { type WGUSPConfigurer struct {
@@ -217,91 +225,75 @@ func (t *WGUSPConfigurer) Close() {
} }
} }
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) { func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet() ipc, err := t.device.IpcGet()
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err) return nil, fmt.Errorf("ipc get: %w", err)
} }
stats, err := findPeerInfo(ipc, peerKey, []string{ return parseTransfers(ipc)
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
} }
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) { func parseTransfers(ipc string) (map[string]WGStats, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) stats := make(map[string]WGStats)
if err != nil { var (
return nil, fmt.Errorf("parse key: %w", err) currentKey string
} currentStats WGStats
hasPeer bool
hexKey := hex.EncodeToString(peerKeyParsed[:]) )
lines := strings.Split(ipc, "\n")
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key, // If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop. // this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") && foundPeer { if strings.HasPrefix(line, "public_key=") {
break peerID := strings.TrimPrefix(line, "public_key=")
} h, err := hex.DecodeString(peerID)
if err != nil {
// Identify the peer with the specific public key return nil, fmt.Errorf("decode peerID: %w", err)
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
for _, key := range searchConfigKeys {
if foundPeer && strings.HasPrefix(line, key+"=") {
v := strings.SplitN(line, "=", 2)
configFound[v[0]] = v[1]
} }
currentKey = base64.StdEncoding.EncodeToString(h)
currentStats = WGStats{} // Reset stats for the new peer
hasPeer = true
stats[currentKey] = currentStats
continue
}
if !hasPeer {
continue
}
key := strings.SplitN(line, "=", 2)
if len(key) != 2 {
continue
}
switch key[0] {
case ipcKeyLastHandshakeTimeSec:
hs, err := toLastHandshake(key[1])
if err != nil {
return nil, err
}
currentStats.LastHandshake = hs
stats[currentKey] = currentStats
case ipcKeyRxBytes:
rxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
currentStats.RxBytes = rxBytes
stats[currentKey] = currentStats
case ipcKeyTxBytes:
TxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
currentStats.TxBytes = TxBytes
stats[currentKey] = currentStats
} }
} }
// todo: use multierr return stats, nil
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
} }
func toWgUserspaceString(wgCfg wgtypes.Config) string { func toWgUserspaceString(wgCfg wgtypes.Config) string {
@@ -355,6 +347,18 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String() return sb.String()
} }
func toLastHandshake(stringVar string) (time.Time, error) {
sec, err := strconv.ParseInt(stringVar, 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
return time.Unix(sec, 0), nil
}
func toBytes(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark return nbnet.ControlPlaneMark

View File

@@ -2,10 +2,8 @@ package configurer
import ( import (
"encoding/hex" "encoding/hex"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -34,58 +32,35 @@ errno=0
` `
func Test_findPeerInfo(t *testing.T) { func Test_parseTransfers(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
peerKey string peerKey string
searchKeys []string want WGStats
want map[string]string
wantErr bool
}{ }{
{ {
name: "single", name: "single",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
searchKeys: []string{"tx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 0,
"tx_bytes": "38333", RxBytes: 0,
}, },
wantErr: false,
}, },
{ {
name: "multiple", name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes", "rx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 38333,
"tx_bytes": "38333", RxBytes: 2224,
"rx_bytes": "2224",
}, },
wantErr: false,
}, },
{ {
name: "lastpeer", name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58", peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "rx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 1212111,
"tx_bytes": "1212111", RxBytes: 1929999999,
"rx_bytes": "1929999999",
}, },
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
key, err := wgtypes.NewKey(res) key, err := wgtypes.NewKey(res)
require.NoError(t, err) require.NoError(t, err)
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys) stats, err := parseTransfers(ipcFixture)
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)) if err != nil {
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys) require.NoError(t, err)
return
}
stat, ok := stats[key.String()]
if !ok {
require.True(t, ok)
return
}
require.Equal(t, tt.want, stat)
}) })
} }
} }

View File

@@ -16,5 +16,5 @@ type WGConfigurer interface {
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error
Close() Close()
GetStats(peerKey string) (configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
} }

View File

@@ -212,9 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device() return w.tun.Device()
} }
// GetStats returns the last handshake time, rx and tx bytes for the given peer // GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats(peerKey) return w.configurer.GetStats()
} }
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {

View File

@@ -24,6 +24,8 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run" !define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -49,6 +51,10 @@ ShowInstDetails Show
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}" !define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@@ -58,9 +64,6 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
!define MUI_UNABORTWARNING !define MUI_UNABORTWARNING
@@ -70,13 +73,16 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
!insertmacro MUI_UNPAGE_WELCOME
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM !insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES !insertmacro MUI_UNPAGE_INSTFILES
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
Var AutostartCheckbox Var AutostartCheckbox
Var AutostartEnabled Var AutostartEnabled
; Variables for uninstall data deletion option
Var DeleteDataCheckbox
Var DeleteDataEnabled
###################################################################### ######################################################################
; Function to create the autostart options page ; Function to create the autostart options page
@@ -104,8 +114,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts" ${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked ${NSD_Check} $AutostartCheckbox
StrCpy $AutostartEnabled "1" ; Default to enabled StrCpy $AutostartEnabled "1"
nsDialogs::Show nsDialogs::Show
FunctionEnd FunctionEnd
@@ -115,6 +125,30 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled ${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd FunctionEnd
; Function to create the uninstall data deletion page
Function un.DeleteDataPage
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
Pop $DeleteDataCheckbox
${NSD_Uncheck} $DeleteDataCheckbox
StrCpy $DeleteDataEnabled "0"
nsDialogs::Show
FunctionEnd
; Function to handle leaving the data deletion page
Function un.DeleteDataPageLeave
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -176,10 +210,10 @@ ${EndIf}
FunctionEnd FunctionEnd
###################################################################### ######################################################################
Section -MainProgram Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
# SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -225,31 +259,58 @@ SectionEnd
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry ; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."
${If} $DeleteDataEnabled == "1"
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
ClearErrors
RMDir /r "${NETBIRD_DATA_DIR}"
IfErrors 0 +2 ; If no errors, jump over the message
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
DetailPrint "Netbird data directory removal complete."
${Else}
DetailPrint "User did not opt to delete Netbird data."
${EndIf}
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
DetailPrint "Waiting for service handle to be released..."
Sleep 3000 Sleep 3000
DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
DetailPrint "Removing shortcuts..."
SetShellVarContext all SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk" Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk" Delete "$SMPROGRAMS\${APP_NAME}.lnk"
DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR" EnVar::DeleteValue "path" "$INSTDIR"
DetailPrint "Uninstallation finished."
SectionEnd SectionEnd

View File

@@ -76,12 +76,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)

View File

@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxDesktopClient { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }

View File

@@ -101,7 +101,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
} }
if !p.providerConfig.DisablePromptLogin { if !p.providerConfig.DisablePromptLogin {
params = append(params, oauth2.SetAuthURLParam("prompt", "login")) if p.providerConfig.LoginFlag.IsPromptLogin() {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
} }
authURL := p.oAuthConfig.AuthCodeURL(state, params...) authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -7,15 +7,36 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
) )
func TestPromptLogin(t *testing.T) { func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct { tt := []struct {
name string name string
prompt bool loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
}{ }{
{"PromptLogin", true}, {
{"NoPromptLogin", false}, name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
} }
for _, tc := range tt { for _, tc := range tt {
@@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize", AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"}, RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true, UseIDToken: true,
DisablePromptLogin: !tc.prompt, LoginFlag: tc.loginFlag,
} }
pkce, err := NewPKCEAuthorizationFlow(config) pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil { if err != nil {
@@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to request auth info: %v", err) t.Fatalf("Failed to request auth info: %v", err)
} }
pattern := "prompt=login"
if tc.prompt { if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, pattern) require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else { } else {
require.NotContains(t, authInfo.VerificationURIComplete, pattern) require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
} }
}) })
} }

View File

@@ -74,6 +74,8 @@ type ConfigInput struct {
DisableNotifications *bool DisableNotifications *bool
DNSLabels domain.List DNSLabels domain.List
LazyConnectionEnabled *bool
} }
// Config Configuration type // Config Configuration type
@@ -138,6 +140,8 @@ type Config struct {
ClientCertKeyPath string ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"` ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -524,6 +528,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
updated = true
}
return updated, nil return updated, nil
} }

303
client/internal/conn_mgr.go Normal file
View File

@@ -0,0 +1,303 @@
package internal
import (
"context"
"os"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
// - Managing lazy connections via the lazyConnManager
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
return nil
}
if enabled {
// if the lazy connection manager is already started, do not start it again
if e.lazyConnMgr != nil {
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager(ctx)
} else {
if e.lazyConnMgr == nil {
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
e.closeManager(ctx)
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(peerIDs []string) {
if e.lazyConnMgr == nil {
return
}
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for _, peerID := range peerIDs {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
continue
}
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(e.ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err)
}
}
}
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
return true
}
if !e.isStartedWithLazyMgr() {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if !lazyconn.IsSupported(conn.AgentVersionString()) {
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerKey,
AllowedIPs: conn.WgConfig().AllowedIps,
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if excluded {
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
conn.Log.Infof("peer added to lazy conn manager")
return
}
func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn, ok := e.peerStore.Remove(peerKey)
if !ok {
return
}
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
}
e.lazyConnMgr.RemovePeer(peerKey)
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(e.ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
func (e *ConnMgr) Close() {
if !e.isStartedWithLazyMgr() {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
}
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
ctx, cancel := context.WithCancel(parentCtx)
e.ctx = ctx
e.ctxCancel = cancel
e.wg.Add(1)
go func() {
defer e.wg.Done()
e.lazyConnMgr.Start(ctx)
}()
}
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {
if e.lazyConnMgr == nil {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
for _, peerID := range e.peerStore.PeersPubKey() {
e.peerStore.PeerConnOpen(ctx, peerID)
}
}
func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.ctxCancel != nil
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {
return nil
}
parsedMinutes, err := strconv.Atoi(envValue)
if err != nil || parsedMinutes <= 0 {
return nil
}
d := time.Duration(parsedMinutes) * time.Minute
return &d
}

View File

@@ -440,7 +440,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableDNS: config.DisableDNS, DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall, DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess, BlockLANAccess: config.BlockLANAccess,
LazyConnectionEnabled: config.LazyConnectionEnabled,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -481,7 +482,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
return signalClient, nil return signalClient, nil
} }
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) { func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey() serverPublicKey, err := client.GetServerPublicKey()

View File

@@ -4,6 +4,7 @@ import (
"archive/zip" "archive/zip"
"bufio" "bufio"
"bytes" "bytes"
"compress/gzip"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -376,6 +377,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall)) configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {
@@ -533,6 +535,33 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err) return fmt.Errorf("add client log file to zip: %w", err)
} }
// add latest rotated log file
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
stdErrLogPath := filepath.Join(logDir, errorLogFile) stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile) stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
@@ -563,16 +592,13 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
} }
}() }()
var logReader io.Reader var logReader io.Reader = logFile
if g.anonymize { if g.anonymize {
var writer *io.PipeWriter var writer *io.PipeWriter
logReader, writer = io.Pipe() logReader, writer = io.Pipe()
go anonymizeLog(logFile, writer, g.anonymizer) go anonymizeLog(logFile, writer, g.anonymizer)
} else {
logReader = logFile
} }
if err := g.addFileToZip(logReader, targetName); err != nil { if err := g.addFileToZip(logReader, targetName); err != nil {
return fmt.Errorf("add %s to zip: %w", targetName, err) return fmt.Errorf("add %s to zip: %w", targetName, err)
} }
@@ -580,6 +606,44 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
return nil return nil
} }
// addSingleLogFileGz adds a single gzipped log file to the archive
func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
f, err := os.Open(logPath)
if err != nil {
return fmt.Errorf("open gz log file %s: %w", targetName, err)
}
defer f.Close()
gzr, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("create gzip reader: %w", err)
}
defer gzr.Close()
var logReader io.Reader = gzr
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(gzr, pw, g.anonymizer)
}
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
if _, err := io.Copy(gw, logReader); err != nil {
return fmt.Errorf("re-gzip: %w", err)
}
if err := gw.Close(); err != nil {
return fmt.Errorf("close gzip writer: %w", err)
}
if err := g.addFileToZip(&buf, targetName); err != nil {
return fmt.Errorf("add anonymized gz: %w", err)
}
return nil
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error { func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{ header := &zip.FileHeader{
Name: filename, Name: filename,

View File

@@ -1,7 +1,6 @@
package dns_test package dns_test
import ( import (
"net"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -9,6 +8,7 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
) )
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
r.SetQuestion("example.com.", dns.TypeA) r.SetQuestion("example.com.", dns.TypeA)
// Create test writer // Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations - only highest priority handler should be called // Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
@@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
@@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request // Create and execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
// Verify expectations // Verify expectations
@@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
}).Once() }).Once()
// Execute // Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
// Verify all handlers were called in order // Verify all handlers were called in order
@@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3.AssertExpectations(t) handler3.AssertExpectations(t)
} }
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) { func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations // Setup expectations
for priority, handler := range handlers { for priority, handler := range handlers {
@@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state // Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Highest priority handler (routeHandler) should be called // Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
@@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 2: Remove highest priority handler // Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called // Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
@@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
@@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler // Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault) chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks { for _, m := range mocks {
@@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
// Execute request // Execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r) chain.ServeDNS(&test.MockResponseWriter{}, r)
// Verify each handler was called exactly as expected // Verify each handler was called exactly as expected
for _, h := range tt.addHandlers { for _, h := range tt.addHandlers {
@@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations // Setup handler expectations
for pattern, handler := range handlers { for pattern, handler := range handlers {
@@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
handler := &nbdns.MockHandler{} handler := &nbdns.MockHandler{}
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA) r.SetQuestion(tt.queryPattern, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any // First verify no handler is called before adding any
chain.ServeDNS(w, r) chain.ServeDNS(w, r)

View File

@@ -1,130 +0,0 @@
package dns
import (
"fmt"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
type registrationMap map[string]struct{}
type localResolver struct {
registeredMap registrationMap
records sync.Map // key: string (domain_class_type), value: []dns.RR
}
func (d *localResolver) MatchSubdomains() bool {
return true
}
func (d *localResolver) stop() {
}
// String returns a string representation of the local resolver
func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}
// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 {
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(r)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
if len(r.Question) == 0 {
return nil
}
question := r.Question[0]
question.Name = strings.ToLower(question.Name)
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
value, found := d.records.Load(key)
if !found {
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
r.Question[0].Qtype = dns.TypeCNAME
return d.lookupRecords(r)
}
return nil
}
records, ok := value.([]dns.RR)
if !ok {
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
return nil
}
// if there's more than one record, rotate them (round-robin)
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records.Store(key, records)
}
return records
}
// registerRecord stores a new record by appending it to any existing list
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
rr, err := dns.NewRR(record.String())
if err != nil {
return "", fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
// load any existing slice of records, then append
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
records := existing.([]dns.RR)
records = append(records, rr)
// store updated slice
d.records.Store(key, records)
return key, nil
}
// deleteRecord removes *all* records under the recordKey.
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}
// buildRecordKey consistently generates a key: name_class_type
func buildRecordKey(name string, class, qType uint16) string {
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
}
func (d *localResolver) probeAvailability() {}

View File

@@ -0,0 +1,149 @@
package local
import (
"fmt"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
)
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
}
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
// String returns a string representation of the local resolver
func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// TODO: return success if we have a different record type for the same name, relevant for search domains
replyMessage.Rcode = dns.RcodeNameError
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
records, found := d.records[question]
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
}
return nil
}
recordsCopy := slices.Clone(records)
d.mu.RUnlock()
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
}
d.mu.Unlock()
}
return recordsCopy
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
}
}
}
// RegisterRecord stores a new record by appending it to any existing list
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
d.mu.Lock()
defer d.mu.Unlock()
return d.registerRecord(record)
}
// registerRecord performs the registration with the lock already held
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
rr, err := dns.NewRR(record.String())
if err != nil {
return fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
q := dns.Question{
Name: strings.ToLower(dns.Fqdn(header.Name)),
Qtype: header.Rrtype,
Qclass: header.Class,
}
d.records[q] = append(d.records[q], rr)
return nil
}

View File

@@ -0,0 +1,472 @@
package local
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := NewResolver()
_ = resolver.RegisterRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}
// TestLocalResolver_Update_StaleRecord verifies that updating
// a record correctly replaces the old one, preventing stale entries.
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
recordName := "host.example.com."
recordType := dns.TypeA
recordClass := dns.ClassINET
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
}
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
resolver := NewResolver()
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
// Apply first update
resolver.Update(update1)
// Verify first update
resolver.mu.RLock()
rrSlice1, found1 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found1, "Record key %s not found after first update", recordKey)
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(update2)
// Verify second update
resolver.mu.RLock()
rrSlice2, found2 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found2, "Record key %s not found after second update", recordKey)
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
}
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
// with the same question are stored properly
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
resolver := NewResolver()
recordName := "multi.example.com."
recordType := dns.TypeA
// Create two records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records
resolver.Update(update)
// Create question that matches both records
question := dns.Question{
Name: recordName,
Qtype: recordType,
Qclass: dns.ClassINET,
}
// Verify both records are stored
resolver.mu.RLock()
records, found := resolver.records[question]
resolver.mu.RUnlock()
require.True(t, found, "Records for question %v not found", question)
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
// Verify both record data values are present
recordStrings := []string{records[0].String(), records[1].String()}
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
}
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
func TestLocalResolver_RecordRotation(t *testing.T) {
resolver := NewResolver()
recordName := "rotation.example.com."
recordType := dns.TypeA
// Create three records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
}
record3 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records
resolver.Update(update)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
// First lookup - should return the records in original order
var responses [3]*dns.Msg
// Perform three lookups to verify rotation
for i := 0; i < 3; i++ {
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responses[i] = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
}
// Verify all three responses contain answers
for i, resp := range responses {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
}
// Verify the first record in each response is different due to rotation
firstRecordIPs := []string{
responses[0].Answer[0].String(),
responses[1].Answer[0].String(),
responses[2].Answer[0].String(),
}
// Each record should be different (rotated)
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
// After three rotations, we should have cycled through all records
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
}
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
resolver := NewResolver()
// Create record with lowercase name
lowerCaseRecord := nbdns.SimpleRecord{
Name: "lower.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.10.10.10",
}
// Create record with mixed case name
mixedCaseRecord := nbdns.SimpleRecord{
Name: "MiXeD.ExAmPlE.CoM.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "20.20.20.20",
}
// Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
testCases := []struct {
name string
queryName string
expectedRData string
shouldResolve bool
}{
{
name: "Query lowercase with lowercase record",
queryName: "lower.example.com.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query uppercase with lowercase record",
queryName: "LOWER.EXAMPLE.COM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query mixed case with lowercase record",
queryName: "LoWeR.eXaMpLe.CoM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query lowercase with mixed case record",
queryName: "mixed.example.com.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query uppercase with mixed case record",
queryName: "MIXED.EXAMPLE.COM.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query with different casing pattern",
queryName: "mIxEd.ExaMpLe.cOm.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query non-existent domain",
queryName: "nonexistent.example.com.",
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case name
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 {
// Expected no answer, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected IP address %s, got: %s",
tc.expectedRData, answerString)
})
}
}
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
// to checking for CNAME records when the requested record type isn't found
func TestLocalResolver_CNAMEFallback(t *testing.T) {
resolver := NewResolver()
// Create a CNAME record (but no A record for this name)
cnameRecord := nbdns.SimpleRecord{
Name: "alias.example.com.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
// Create an A record for the CNAME target
targetRecord := nbdns.SimpleRecord{
Name: "target.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.100.100",
}
// Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
testCases := []struct {
name string
queryName string
queryType uint16
expectedType string
expectedRData string
shouldResolve bool
}{
{
name: "Directly query CNAME record",
queryName: "alias.example.com.",
queryType: dns.TypeCNAME,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query A record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query AAAA record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeAAAA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query direct A record",
queryName: "target.example.com.",
queryType: dns.TypeA,
expectedType: "A",
expectedRData: "192.168.100.100",
shouldResolve: true,
},
{
name: "Query non-existent name",
queryName: "nonexistent.example.com.",
queryType: dns.TypeA,
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case parameters
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
// Expected no resolution, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a successful response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedType,
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
})
}
}

View File

@@ -1,88 +0,0 @@
package dns
import (
"strings"
"testing"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_, _ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@@ -1,26 +0,0 @@
package dns
import (
"net"
"github.com/miekg/dns"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@@ -15,6 +15,8 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -46,8 +48,6 @@ type Server interface {
ProbeAvailability() ProbeAvailability()
} }
type handlerID string
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
domain string domain string
groups []*nbdns.NameServerGroup groups []*nbdns.NameServerGroup
@@ -61,7 +61,7 @@ type DefaultServer struct {
mux sync.Mutex mux sync.Mutex
service service service service
dnsMuxMap registeredHandlerMap dnsMuxMap registeredHandlerMap
localResolver *localResolver localResolver *local.Resolver
wgInterface WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
updateSerial uint64 updateSerial uint64
@@ -84,9 +84,9 @@ type DefaultServer struct {
type handlerWithStop interface { type handlerWithStop interface {
dns.Handler dns.Handler
stop() Stop()
probeAvailability() ProbeAvailability()
id() handlerID ID() types.HandlerID
} }
type handlerWrapper struct { type handlerWrapper struct {
@@ -95,7 +95,7 @@ type handlerWrapper struct {
priority int priority int
} }
type registeredHandlerMap map[handlerID]handlerWrapper type registeredHandlerMap map[types.HandlerID]handlerWrapper
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(
@@ -171,16 +171,14 @@ func newDefaultServer(
handlerChain := NewHandlerChain() handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
disableSys: disableSys, disableSys: disableSys,
service: dnsService, service: dnsService,
handlerChain: handlerChain, handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{ localResolver: local.NewResolver(),
registeredMap: make(registrationMap),
},
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager, stateManager: stateManager,
@@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Add(1) wg.Add(1)
go func(mux handlerWithStop) { go func(mux handlerWithStop) {
defer wg.Done() defer wg.Done()
mux.probeAvailability() mux.ProbeAvailability()
}(mux.handler) }(mux.handler)
} }
wg.Wait() wg.Wait()
@@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.service.Stop() s.service.Stop()
} }
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones) localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil { if err != nil {
return fmt.Errorf("local handler updater: %w", err) return fmt.Errorf("local handler updater: %w", err)
} }
@@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates) s.updateMux(muxUpdates)
// register local records // register local records
s.updateLocalResolver(localRecordsByDomain) s.localResolver.Update(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
) )
} }
func (s *DefaultServer) buildLocalHandlerUpdate( func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
customZones []nbdns.CustomZone,
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
var muxUpdates []handlerWrapper var muxUpdates []handlerWrapper
localRecords := make(map[string][]nbdns.SimpleRecord) var localRecords []nbdns.SimpleRecord
for _, customZone := range customZones { for _, customZone := range customZones {
if len(customZone.Records) == 0 { if len(customZone.Records) == 0 {
@@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}) })
// group all records under this domain
for _, record := range customZone.Records { for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass { if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class) log.Warnf("received an invalid class type: %s", record.Class)
continue continue
} }
// zone records contain the fqdn, so we can just flatten them
key := buildRecordKey(record.Name, class, uint16(record.Type)) localRecords = append(localRecords, record)
localRecords[key] = append(localRecords[key], record)
} }
} }
@@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
} }
if len(handler.upstreamServers) == 0 { if len(handler.upstreamServers) == 0 {
handler.stop() handler.Stop()
log.Errorf("received a nameserver group with an invalid nameserver list") log.Errorf("received a nameserver group with an invalid nameserver list")
continue continue
} }
@@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests // this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap { for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority) s.deregisterHandler([]string{existing.domain}, existing.priority)
existing.handler.stop() existing.handler.Stop()
} }
muxUpdateMap := make(registeredHandlerMap) muxUpdateMap := make(registeredHandlerMap)
@@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
containsRootUpdate = true containsRootUpdate = true
} }
s.registerHandler([]string{update.domain}, update.handler, update.priority) s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.id()] = update muxUpdateMap[update.handler.ID()] = update
} }
// If there's no root update and we had a root handler, restore it // If there's no root update and we had a root handler, restore it
@@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
} }
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
// remove old records that are no longer present
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
s.localResolver.deleteRecord(key)
}
}
updatedMap := make(registrationMap)
for _, recs := range update {
for _, rec := range recs {
// convert the record to a dns.RR and register
key, err := s.localResolver.registerRecord(rec)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v",
rec.String(), err)
continue
}
updatedMap[key] = struct{}{}
}
}
s.localResolver.registeredMap = updatedMap
}
func getNSHostPort(ns nbdns.NameServer) string { func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
} }

View File

@@ -23,6 +23,9 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
} }
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("8.8.8.8"), IP: netip.MustParseAddr("8.8.8.8"),
@@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
} }
dummyHandler := &localResolver{} dummyHandler := local.NewResolver()
testCases := []struct { testCases := []struct {
name string name string
initUpstreamMap registeredHandlerMap initUpstreamMap registeredHandlerMap
initLocalMap registrationMap initLocalRecords []nbdns.SimpleRecord
initSerial uint64 initSerial uint64
inputSerial uint64 inputSerial uint64
inputUpdate nbdns.Config inputUpdate nbdns.Config
shouldFail bool shouldFail bool
expectedUpstreamMap registeredHandlerMap expectedUpstreamMap registeredHandlerMap
expectedLocalMap registrationMap expectedLocalQs []dns.Question
}{ }{
{ {
name: "Initial Config Should Succeed", name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
@@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{ expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
dummyHandler.id(): handlerWrapper{ dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
generateDummyHandler(".", nameServers).id(): handlerWrapper{ generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone, domain: nbdns.RootZone,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityDefault, priority: PriorityDefault,
}, },
}, },
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
}, },
{ {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: buildRecordKey(zoneRecords[0].Name, 1, 1), domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
@@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{ expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
@@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) {
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
}, },
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
}, },
{ {
name: "Smaller Config Serial Should Be Skipped", name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap), initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 2, initSerial: 2,
inputSerial: 1, inputSerial: 1,
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Empty NS Group Domain Or Not Primary Element Should Fail", name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap), initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Invalid NS Group Nameservers list Should Fail", name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap), initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Invalid Custom Zone Records list Should Skip", name: "Invalid Custom Zone Records list Should Skip",
initLocalMap: make(registrationMap), initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{ expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: ".", domain: ".",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityDefault, priority: PriorityDefault,
}}, }},
}, },
{ {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
@@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) {
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true}, inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap), expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalMap: make(registrationMap), expectedLocalQs: []dns.Question{},
}, },
{ {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
@@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) {
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false}, inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap), expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalMap: make(registrationMap), expectedLocalQs: []dns.Question{},
}, },
} }
@@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) {
}() }()
dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.updateSerial = testCase.initSerial dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) {
} }
} }
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) { var responseMSG *dns.Msg
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap)) responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
} }
for key := range testCase.expectedLocalMap { if len(testCase.expectedLocalQs) > 0 {
_, found := dnsServer.localResolver.registeredMap[key] assert.NotNil(t, responseMSG, "response message should not be nil")
if !found { assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap) assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
} }
}) })
} }
@@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
dnsServer.dnsMuxMap = registeredHandlerMap{ dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: &localResolver{}, handler: &local.Resolver{},
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
} }
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.updateSerial = 0 dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
@@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) {
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop() defer dnsServer.Stop()
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0]) err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
ctx: context.Background(), ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}), service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: local.NewResolver(),
registeredMap: make(registrationMap), handlerChain: NewHandlerChain(),
}, hostManager: hostManager,
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{ currentConfig: HostDNSConfig{
Domains: []DomainConfig{ Domains: []DomainConfig{
{false, "domain0", false}, {false, "domain0", false},
@@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tc.query, dns.TypeA) r.SetQuestion(tc.query, dns.TypeA)
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
if mh, ok := tc.expectedHandler.(*MockHandler); ok { if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.On("ServeDNS", mock.Anything, r).Once() mh.On("ServeDNS", mock.Anything, r).Once()
@@ -1037,9 +1047,9 @@ type mockHandler struct {
} }
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) stop() {} func (m *mockHandler) Stop() {}
func (m *mockHandler) probeAvailability() {} func (m *mockHandler) ProbeAvailability() {}
func (m *mockHandler) id() handlerID { return handlerID(m.Id) } func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{} type mockService struct{}
@@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
name string name string
initialHandlers registeredHandlerMap initialHandlers registeredHandlerMap
updates []handlerWrapper updates []handlerWrapper
expectedHandlers map[string]string // map[handlerID]domain expectedHandlers map[string]string // map[HandlerID]domain
description string description string
}{ }{
{ {
@@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Check each expected handler // Check each expected handler
for id, expectedDomain := range tt.expectedHandlers { for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[handlerID(id)] handler, exists := server.dnsMuxMap[types.HandlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id) assert.True(t, exists, "Expected handler %s not found", id)
if exists { if exists {
assert.Equal(t, expectedDomain, handler.domain, assert.Equal(t, expectedDomain, handler.domain,
@@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
} }
// Verify no unexpected handlers exist // Verify no unexpected handlers exist
for handlerID := range server.dnsMuxMap { for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(handlerID)] _, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", handlerID) assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
} }
// Verify the handlerChain state and order // Verify the handlerChain state and order
@@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) {
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
wgInterface: &mocWGIface{}, wgInterface: &mocWGIface{},
hostManager: mockHostConfig, hostManager: mockHostConfig,
localResolver: &localResolver{}, localResolver: &local.Resolver{},
service: mockSvc, service: mockSvc,
statusRecorder: peer.NewRecorder("test"), statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),
@@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
ctx: context.Background(), ctx: context.Background(),
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
hostManager: mockHostConfig, hostManager: mockHostConfig,
localResolver: &localResolver{}, localResolver: &local.Resolver{},
service: mockSvc, service: mockSvc,
statusRecorder: peer.NewRecorder("test"), statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),
@@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
ctx: context.Background(), ctx: context.Background(),
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
hostManager: mockHostConfig, hostManager: mockHostConfig,
localResolver: &localResolver{}, localResolver: &local.Resolver{},
service: mockSvc, service: mockSvc,
statusRecorder: peer.NewRecorder("test"), statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),
@@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) {
ctx: context.Background(), ctx: context.Background(),
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
hostManager: mockHostConfig, hostManager: mockHostConfig,
localResolver: &localResolver{}, localResolver: &local.Resolver{},
service: mockSvc, service: mockSvc,
statusRecorder: peer.NewRecorder("test"), statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),

View File

@@ -30,9 +30,12 @@ const (
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS" systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusResolvConfModeForeign = "foreign" systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
dnsSecDisabled = "no"
) )
type systemdDbusConfigurator struct { type systemdDbusConfigurator struct {
@@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
Family: unix.AF_INET, Family: unix.AF_INET,
Address: ipAs4[:], Address: ipAs4[:],
} }
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}) if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
if err != nil { return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err) }
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSSEC to 'no': %v", err)
} }
var ( var (

View File

@@ -0,0 +1,26 @@
package test
import (
"net"
"github.com/miekg/dns"
)
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *MockResponseWriter) Close() error { return nil }
func (rw *MockResponseWriter) TsigStatus() error { return nil }
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
func (rw *MockResponseWriter) Hijack() {}

View File

@@ -0,0 +1,3 @@
package types
type HandlerID string

View File

@@ -19,6 +19,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string {
} }
// ID returns the unique handler ID // ID returns the unique handler ID
func (u *upstreamResolverBase) id() handlerID { func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers) servers := slices.Clone(u.upstreamServers)
slices.Sort(servers) slices.Sort(servers)
hash := sha256.New() hash := sha256.New()
hash.Write([]byte(u.domain + ":")) hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ","))) hash.Write([]byte(strings.Join(servers, ",")))
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
} }
func (u *upstreamResolverBase) MatchSubdomains() bool { func (u *upstreamResolverBase) MatchSubdomains() bool {
return true return true
} }
func (u *upstreamResolverBase) stop() { func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel() u.cancel()
} }
@@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
) )
} }
// probeAvailability tests all upstream servers simultaneously and // ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work // disables the resolver if none work
func (u *upstreamResolverBase) probeAvailability() { func (u *upstreamResolverBase) ProbeAvailability() {
u.mutex.Lock() u.mutex.Lock()
defer u.mutex.Unlock() defer u.mutex.Unlock()

View File

@@ -8,6 +8,8 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
) )
func TestUpstreamResolver_ServeDNS(t *testing.T) { func TestUpstreamResolver_ServeDNS(t *testing.T) {
@@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
} }
var responseMSG *dns.Msg var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{ responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m responseMSG = m
return nil return nil
@@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver.failsTillDeact = 0 resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100 resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &mockResponseWriter{ responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil }, WriteMsgFunc: func(m *dns.Msg) error { return nil },
} }

View File

@@ -5,7 +5,6 @@ package dns
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -18,5 +17,4 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
} }

View File

@@ -1,7 +1,6 @@
package dns package dns
import ( import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -13,6 +12,5 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error) GetInterfaceGUIDString() (string, error)
} }

View File

@@ -33,6 +33,8 @@ type DNSForwarder struct {
dnsServer *dns.Server dnsServer *dns.Server
mux *dns.ServeMux mux *dns.ServeMux
tcpServer *dns.Server
tcpMux *dns.ServeMux
mutex sync.RWMutex mutex sync.RWMutex
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
@@ -50,22 +52,41 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager
} }
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress) log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux()
dnsServer := &dns.Server{ // UDP server
mux := dns.NewServeMux()
f.mux = mux
f.dnsServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
} }
f.dnsServer = dnsServer // TCP server
f.mux = mux tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux
f.tcpServer = &dns.Server{
Addr: f.listenAddress,
Net: "tcp",
Handler: tcpMux,
}
f.UpdateDomains(entries) f.UpdateDomains(entries)
return dnsServer.ListenAndServe() errCh := make(chan error, 2)
}
go func() {
log.Infof("DNS UDP listener running on %s", f.listenAddress)
errCh <- f.dnsServer.ListenAndServe()
}()
go func() {
log.Infof("DNS TCP listener running on %s", f.listenAddress)
errCh <- f.tcpServer.ListenAndServe()
}()
// return the first error we get (e.g. bind failure or shutdown)
return <-errCh
}
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
@@ -77,31 +98,41 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
} }
oldDomains := filterDomains(f.fwdEntries) oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains { for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString()) f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
} }
newDomains := filterDomains(entries) newDomains := filterDomains(entries)
for _, d := range newDomains { for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery) f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
} }
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil { var result *multierror.Error
return nil
if f.dnsServer != nil {
if err := f.dnsServer.ShutdownContext(ctx); err != nil {
result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err))
}
} }
return f.dnsServer.ShutdownContext(ctx) if f.tcpServer != nil {
if err := f.tcpServer.ShutdownContext(ctx); err != nil {
result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err))
}
}
return nberrors.FormatErrorOrNil(result)
} }
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 { if len(query.Question) == 0 {
return return nil
} }
question := query.Question[0] question := query.Question[0]
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
@@ -123,20 +154,53 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
return return nil
} }
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, resp, domain, err) f.handleDNSError(w, query, resp, domain, err)
return return nil
} }
f.updateInternalState(domain, ips) f.updateInternalState(domain, ips)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
@@ -179,7 +243,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
} }
// handleDNSError processes DNS lookup errors and sends an appropriate error response // handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError var dnsErr *net.DNSError
switch { switch {
@@ -191,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai
} }
if dnsErr.Server != "" { if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
} else { } else {
log.Warnf(errResolveFailed, domain, err) log.Warnf(errResolveFailed, domain, err)
} }

View File

@@ -33,6 +33,7 @@ type Manager struct {
statusRecorder *peer.Status statusRecorder *peer.Status
fwRules []firewall.Rule fwRules []firewall.Rule
tcpRules []firewall.Rule
dnsForwarder *DNSForwarder dnsForwarder *DNSForwarder
} }
@@ -107,6 +108,13 @@ func (m *Manager) allowDNSFirewall() error {
} }
m.fwRules = dnsRules m.fwRules = dnsRules
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
m.tcpRules = tcpRules
return nil return nil
} }
@@ -117,7 +125,13 @@ func (m *Manager) dropDNSFirewall() error {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
} }
} }
for _, rule := range m.tcpRules {
if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
m.fwRules = nil m.fwRules = nil
m.tcpRules = nil
return nberrors.FormatErrorOrNil(mErr) return nberrors.FormatErrorOrNil(mErr)
} }

View File

@@ -38,6 +38,7 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
@@ -122,6 +123,8 @@ type EngineConfig struct {
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
LazyConnectionEnabled bool
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -134,6 +137,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer // peerConns is a map that holds all the peers that are known to this peer
peerStore *peerstore.Store peerStore *peerstore.Store
connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc afterPeerHook nbnet.RemoveHookFunc
@@ -170,7 +175,8 @@ type Engine struct {
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server sshServer nbssh.Server
statusRecorder *peer.Status statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
firewall firewallManager.Manager firewall firewallManager.Manager
routeManager routemanager.Manager routeManager routemanager.Manager
@@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
}
// stopping network monitor first to avoid starting the engine again // stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil { if e.networkMonitor != nil {
e.networkMonitor.Stop() e.networkMonitor.Stop()
@@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
err := e.removeAllPeers() if err := e.removeAllPeers(); err != nil {
if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
} }
@@ -405,8 +414,7 @@ func (e *Engine) Start() error {
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
err = e.wgInterfaceCreate() if err = e.wgInterfaceCreate(); err != nil {
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close() e.close()
return fmt.Errorf("create wg interface: %w", err) return fmt.Errorf("create wg interface: %w", err)
@@ -442,6 +450,11 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
} }
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start() e.srWatcher.Start()
@@ -450,7 +463,6 @@ func (e *Engine) Start() error {
// starting network monitor at the very last to avoid disruptions // starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor() e.startNetworkMonitor()
return nil return nil
} }
@@ -550,6 +562,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
if !ok {
continue
}
if currentPeer.AgentVersionString() != p.AgentVersion {
modified = append(modified, p)
continue
}
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey) allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
if !ok { if !ok {
continue continue
@@ -559,8 +581,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
continue continue
} }
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()) if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err) log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
} }
} }
@@ -621,16 +642,11 @@ func (e *Engine) removePeer(peerKey string) error {
e.sshServer.RemoveAuthorizedKey(peerKey) e.sshServer.RemoveAuthorizedKey(peerKey)
} }
defer func() { e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey)
if err != nil {
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
}
}()
conn, exists := e.peerStore.Remove(peerKey) err := e.statusRecorder.RemovePeer(peerKey)
if exists { if err != nil {
conn.Close() log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
} }
return nil return nil
} }
@@ -952,12 +968,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil return nil
} }
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err)
}
if e.firewall != nil { if e.firewall != nil {
if localipfw, ok := e.firewall.(localIpUpdater); ok { if localipfw, ok := e.firewall.(localIpUpdater); ok {
if err := localipfw.UpdateLocalIPs(); err != nil { if err := localipfw.UpdateLocalIPs(); err != nil {
log.Errorf("failed to update local IPs: %v", err) log.Errorf("failed to update local IPs: %v", err)
} }
} }
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
// This needs to be toggled before applying routes.
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
} }
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
@@ -976,7 +1004,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules // Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil { forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err) log.Errorf("failed to update forward rules, err: %v", err)
} }
@@ -1022,6 +1051,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(excludedLazyPeers)
protoDNSConfig := networkMap.GetDNSConfig() protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil { if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
@@ -1155,7 +1188,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
IP: strings.Join(offlinePeer.GetAllowedIps(), ","), IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
PubKey: offlinePeer.GetWgPubKey(), PubKey: offlinePeer.GetWgPubKey(),
FQDN: offlinePeer.GetFqdn(), FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected, ConnStatus: peer.StatusIdle,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
@@ -1191,12 +1224,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerIPs = append(peerIPs, allowedNetIP) peerIPs = append(peerIPs, allowedNetIP)
} }
conn, err := e.createPeerConn(peerKey, peerIPs) conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
if err != nil { if err != nil {
return fmt.Errorf("create peer connection: %w", err) return fmt.Errorf("create peer connection: %w", err)
} }
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close() conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey) return fmt.Errorf("peer already exists: %s", peerKey)
} }
@@ -1205,17 +1243,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
conn.AddBeforeAddPeerHook(e.beforePeerHook) conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook) conn.AddAfterRemovePeerHook(e.afterPeerHook)
} }
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
conn.Open()
return nil return nil
} }
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) { func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey) log.Debugf("creating peer connection %s", pubKey)
wgConfig := peer.WgConfig{ wgConfig := peer.WgConfig{
@@ -1229,11 +1260,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
// randomize connection timeout // randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{ config := peer.ConnConfig{
Key: pubKey, Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(), LocalKey: e.config.WgPrivateKey.PublicKey().String(),
Timeout: timeout, AgentVersion: agentVersion,
WgConfig: wgConfig, Timeout: timeout,
LocalWgPort: e.config.WgPort, WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
RosenpassConfig: peer.RosenpassConfig{ RosenpassConfig: peer.RosenpassConfig{
PubKey: e.getRosenpassPubKey(), PubKey: e.getRosenpassPubKey(),
Addr: e.getRosenpassAddr(), Addr: e.getRosenpassAddr(),
@@ -1249,7 +1281,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
}, },
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore) serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1270,7 +1311,7 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
conn, ok := e.peerStore.PeerConn(msg.Key) conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
if !ok { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)
} }
@@ -1578,13 +1619,39 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states. // and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool { func (e *Engine) RunHealthProbes() bool {
e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy() signalHealthy := e.signal.IsHealthy()
log.Debugf("signal health check: healthy=%t", signalHealthy) log.Debugf("signal health check: healthy=%t", signalHealthy)
managementHealthy := e.mgmClient.IsHealthy() managementHealthy := e.mgmClient.IsHealthy()
log.Debugf("management health check: healthy=%t", managementHealthy) log.Debugf("management health check: healthy=%t", managementHealthy)
results := append(e.probeSTUNs(), e.probeTURNs()...) stuns := slices.Clone(e.STUNs)
turns := slices.Clone(e.TURNs)
if e.wgInterface != nil {
stats, err := e.wgInterface.GetStats()
if err != nil {
log.Warnf("failed to get wireguard stats: %v", err)
e.syncMsgMux.Unlock()
return false
}
for _, key := range e.peerStore.PeersPubKey() {
// wgStats could be zero value, in which case we just reset the stats
wgStats, ok := stats[key]
if !ok {
continue
}
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
}
e.syncMsgMux.Unlock()
results := e.probeICE(stuns, turns)
e.statusRecorder.UpdateRelayStates(results) e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true relayHealthy := true
@@ -1596,37 +1663,16 @@ func (e *Engine) RunHealthProbes() bool {
} }
log.Debugf("relay health check: healthy=%t", relayHealthy) log.Debugf("relay health check: healthy=%t", relayHealthy)
for _, key := range e.peerStore.PeersPubKey() {
wgStats, err := e.wgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
continue
}
// wgStats could be zero value, in which case we just reset the stats
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
allHealthy := signalHealthy && managementHealthy && relayHealthy allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy) log.Debugf("all health checks completed: healthy=%t", allHealthy)
return allHealthy return allHealthy
} }
func (e *Engine) probeSTUNs() []relay.ProbeResult { func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
e.syncMsgMux.Lock() return append(
stuns := slices.Clone(e.STUNs) relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
e.syncMsgMux.Unlock() relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)...,
)
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
}
func (e *Engine) probeTURNs() []relay.ProbeResult {
e.syncMsgMux.Lock()
turns := slices.Clone(e.TURNs)
e.syncMsgMux.Unlock()
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
} }
// restartEngine restarts the engine by cancelling the client context // restartEngine restarts the engine by cancelling the client context
@@ -1813,21 +1859,21 @@ func (e *Engine) Address() (netip.Addr, error) {
return ip.Unmap(), nil return ip.Unmap(), nil
} }
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error { func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil { if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules") log.Warn("firewall is disabled, not updating forwarding rules")
return nil return nil, nil
} }
if len(rules) == 0 { if len(rules) == 0 {
if e.ingressGatewayMgr == nil { if e.ingressGatewayMgr == nil {
return nil return nil, nil
} }
err := e.ingressGatewayMgr.Close() err := e.ingressGatewayMgr.Close()
e.ingressGatewayMgr = nil e.ingressGatewayMgr = nil
e.statusRecorder.SetIngressGwMgr(nil) e.statusRecorder.SetIngressGwMgr(nil)
return err return nil, err
} }
if e.ingressGatewayMgr == nil { if e.ingressGatewayMgr == nil {
@@ -1878,7 +1924,33 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
log.Errorf("failed to update forwarding rules: %v", err) log.Errorf("failed to update forwarding rules: %v", err)
} }
return nberrors.FormatErrorOrNil(merr) return forwardingRules, nberrors.FormatErrorOrNil(merr)
}
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) []string {
excludedPeers := make([]string, 0)
for _, r := range routes {
if r.Peer == "" {
continue
}
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers = append(excludedPeers, r.Peer)
}
for _, r := range rules {
ip := r.TranslatedAddress
for _, p := range peers {
for _, allowedIP := range p.GetAllowedIps() {
if allowedIP != ip.String() {
continue
}
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
excludedPeers = append(excludedPeers, p.GetWgPubKey())
}
}
}
return excludedPeers
} }
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.

View File

@@ -28,8 +28,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@@ -38,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@@ -53,6 +52,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@@ -93,7 +93,7 @@ type MockWGIface struct {
GetFilterFunc func() device.PacketFilter GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error) GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net GetNetFunc func() *netstack.Net
@@ -171,8 +171,8 @@ func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
return m.GetWGDeviceFunc() return m.GetWGDeviceFunc()
} }
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (m *MockWGIface) GetStats() (map[string]configurer.WGStats, error) {
return m.GetStatsFunc(peerKey) return m.GetStatsFunc()
} }
func (m *MockWGIface) GetProxy() wgproxy.Proxy { func (m *MockWGIface) GetProxy() wgproxy.Proxy {
@@ -378,6 +378,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, },
} }
}, },
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return nil
},
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
@@ -400,6 +403,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
type testCase struct { type testCase struct {
name string name string
@@ -770,6 +775,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{} engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() { defer func() {
exitErr := engine.Stop() exitErr := engine.Stop()
@@ -966,6 +973,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
} }
engine.dnsServer = mockDNSServer engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() { defer func() {
exitErr := engine.Stop() exitErr := engine.Stop()
@@ -1476,7 +1485,7 @@ func getConnectedPeers(e *Engine) int {
i := 0 i := 0
for _, id := range e.peerStore.PeersPubKey() { for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id) conn, _ := e.peerStore.PeerConn(id)
if conn.Status() == peer.StatusConnected { if conn.IsConnected() {
i++ i++
} }
} }

View File

@@ -35,6 +35,6 @@ type wgIfaceBase interface {
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetStats(peerKey string) (configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net GetNet() *netstack.Net
} }

View File

@@ -0,0 +1,9 @@
//go:build !linux || android
package activity
import "net"
var (
listenIP = net.IP{127, 0, 0, 1}
)

View File

@@ -0,0 +1,10 @@
//go:build !android
package activity
import "net"
var (
// use this ip to avoid eBPF proxy congestion
listenIP = net.IP{127, 0, 1, 1}
)

View File

@@ -0,0 +1,106 @@
package activity
import (
"fmt"
"net"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
wgIface lazyconn.WGIface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener
}
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
wgIface: wgIface,
peerCfg: cfg,
}
conn, err := d.newConn()
if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
}
d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
if err := d.createEndpoint(); err != nil {
return nil, err
}
d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
return d, nil
}
func (d *Listener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Debugf("exit from activity listener")
} else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
}
break
}
if n < 1 {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
break
}
if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
d.done.Unlock()
}
func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
if err := d.conn.Close(); err != nil {
d.peerCfg.Log.Errorf("failed to close UDP listener: %s", err)
}
d.done.Lock()
}
func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
}
func (d *Listener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{
Port: 0,
IP: listenIP,
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
log.Errorf("failed to create activity listener on %s: %s", addr, err)
return nil, err
}
return conn, nil
}

View File

@@ -0,0 +1,41 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -0,0 +1,95 @@
package activity
import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type Manager struct {
OnActivityChan chan peerid.ConnID
wgIface lazyconn.WGIface
peers map[peerid.ConnID]*Listener
done chan struct{}
mu sync.Mutex
}
func NewManager(wgIface lazyconn.WGIface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener),
done: make(chan struct{}),
}
return m
}
func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.peers[peerCfg.PeerConnID]; ok {
log.Warnf("activity listener already exists for: %s", peerCfg.PublicKey)
return nil
}
listener, err := NewListener(m.wgIface, peerCfg)
if err != nil {
return err
}
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
listener, ok := m.peers[peerConnID]
if !ok {
return
}
log.Debugf("removing activity listener")
delete(m.peers, peerConnID)
listener.Close()
}
func (m *Manager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
close(m.done)
for peerID, listener := range m.peers {
delete(m.peers, peerID)
listener.Close()
}
}
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
listener.ReadPackets()
m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok {
m.mu.Unlock()
return
}
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
}
}

View File

@@ -0,0 +1,162 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
PeerID string
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
type MocWGIface struct {
}
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
func TestManager_MonitorPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
}
case <-time.After(1 * time.Second):
}
}
func TestManager_RemovePeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
if err := trigger(addr); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case <-mgr.OnActivityChan:
t.Fatal("should not have active activity")
case <-time.After(1 * time.Second):
}
}
func TestManager_MultiPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
peer2 := &MocPeer{}
peerCfg2 := lazyconn.PeerConfig{
PublicKey: peer2.PeerID,
PeerConnID: peer2.ConnID(),
Log: log.WithField("peer", "examplePublicKey2"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := mgr.MonitorPeerActivity(peerCfg2); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
for i := 0; i < 2; i++ {
select {
case <-mgr.OnActivityChan:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
}
}
}
func trigger(addr string) error {
// Create a connection to the destination UDP address and port
conn, err := net.Dial("udp", addr)
if err != nil {
return err
}
defer conn.Close()
// Write the bytes to the UDP connection
_, err = conn.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,32 @@
/*
Package lazyconn provides mechanisms for managing lazy connections, which activate on demand to optimize resource usage and establish connections efficiently.
## Overview
The package includes a `Manager` component responsible for:
- Managing lazy connections activated on-demand
- Managing inactivity monitors for lazy connections (based on peer disconnection events)
- Maintaining a list of excluded peers that should always have permanent connections
- Handling remote peer connection initiatives based on peer signaling
## Thread-Safe Operations
The `Manager` ensures thread safety across multiple operations, categorized by caller:
- **Engine (single goroutine)**:
- `AddPeer`: Adds a peer to the connection manager.
- `RemovePeer`: Removes a peer from the connection manager.
- `ActivatePeer`: Activates a lazy connection for a peer. This come from Signal client
- `ExcludePeer`: Marks peers for a permanent connection. Like router peers and other peers that should always have a connection.
- **Connection Dispatcher (any peer routine)**:
- `onPeerConnected`: Suspend the inactivity monitor for an active peer connection.
- `onPeerDisconnected`: Starts the inactivity monitor for a disconnected peer.
- **Activity Manager**:
- `onPeerActivity`: Run peer.Open(context).
- **Inactivity Monitor**:
- `onPeerInactivityTimedOut`: Close peer connection and restart activity monitor.
*/
package lazyconn

View File

@@ -0,0 +1,26 @@
package lazyconn
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
)
func IsLazyConnEnabledByEnv() bool {
val := os.Getenv(EnvEnableLazyConn)
if val == "" {
return false
}
enabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
return false
}
return enabled
}

View File

@@ -0,0 +1,70 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}

View File

@@ -0,0 +1,156 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
trashHold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), trashHold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check to do not receive timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -0,0 +1,404 @@
package manager
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
const (
watcherActivity watcherType = iota
watcherInactivity
)
type watcherType int
type managedPeer struct {
peerCfg *lazyconn.PeerConfig
expectedWatcher watcherType
}
type Config struct {
InactivityThreshold *time.Duration
}
// Manager manages lazy connections
// It is responsible for:
// - Managing lazy connections activated on-demand
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
type Manager struct {
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
cancel context.CancelFunc
onInactive chan peerid.ConnID
}
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
onInactive: make(chan peerid.ConnID),
}
if config.InactivityThreshold != nil {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
m.inactivityThreshold = *config.InactivityThreshold
} else {
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
}
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
ctx, m.cancel = context.WithCancel(ctx)
for {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID)
}
}
}
// ExcludePeer marks peers for a permanent connection
// It removes peers from the managed list if they are added to the exclude list
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
// this case, we suppose that the connection status is connected or connecting.
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
added := make([]string, 0)
excludes := make(map[string]lazyconn.PeerConfig, len(peerConfigs))
for _, peerCfg := range peerConfigs {
log.Infof("update excluded lazy connection list with peer: %s", peerCfg.PublicKey)
excludes[peerCfg.PublicKey] = peerCfg
}
// if a peer is newly added to the exclude list, remove from the managed peers list
for pubKey, peerCfg := range excludes {
if _, wasExcluded := m.excludes[pubKey]; wasExcluded {
continue
}
added = append(added, pubKey)
peerCfg.Log.Infof("peer newly added to lazy connection exclude list")
m.removePeer(pubKey)
}
// if a peer has been removed from exclude list then it should be added to the managed peers
for pubKey, peerCfg := range m.excludes {
if _, stillExcluded := excludes[pubKey]; stillExcluded {
continue
}
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(ctx, peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue
}
}
m.excludes = excludes
return added
}
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
peerCfg.Log.Debugf("adding peer to lazy connection manager")
_, exists := m.excludes[peerCfg.PublicKey]
if exists {
return true, nil
}
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return false, nil
}
if err := m.activityManager.MonitorPeerActivity(peerCfg); err != nil {
return false, err
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherActivity,
}
return false, nil
}
// AddActivePeers adds a list of peers to the lazy connection manager
// suppose these peers was in connected or in connecting states
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
for _, cfg := range peerCfg {
if _, ok := m.managedPeers[cfg.PublicKey]; ok {
cfg.Log.Errorf("peer already managed")
continue
}
if err := m.addActivePeer(ctx, cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err
}
}
return nil
}
func (m *Manager) RemovePeer(peerID string) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.removePeer(peerID)
}
// ActivatePeer activates a peer connection when a signal message is received
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, ok := m.managedPeers[peerID]
if !ok {
return false
}
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
return false
}
// signal messages coming continuously after success activation, with this avoid the multiple activation
if mp.expectedWatcher == watcherInactivity {
return false
}
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
return false
}
mp.peerCfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
return true
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return nil
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherInactivity,
}
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
go im.Start(ctx, m.onInactive)
return nil
}
func (m *Manager) removePeer(peerID string) {
cfg, ok := m.managedPeers[peerID]
if !ok {
return
}
cfg.Log.Infof("removing lazy peer")
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID)
}
func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.cancel()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
log.Infof("lazy connection manager closed")
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherActivity {
mp.peerCfg.Log.Warnf("ignore activity event")
return
}
mp.peerCfg.Log.Infof("detected peer activity")
mp.expectedWatcher = watcherInactivity
mp.peerCfg.Log.Infof("starting inactivity monitor")
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
// just in case free up
m.inactivityMonitors[peerConnID].PauseTimer()
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

View File

@@ -0,0 +1,16 @@
package lazyconn
import (
"net/netip"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type PeerConfig struct {
PublicKey string
AllowedIPs []netip.Prefix
PeerConnID id.ConnID
Log *log.Entry
}

View File

@@ -0,0 +1,41 @@
package lazyconn
import (
"strings"
"github.com/hashicorp/go-version"
)
var (
minVersion = version.Must(version.NewVersion("0.45.0"))
)
func IsSupported(agentVersion string) bool {
if agentVersion == "development" {
return true
}
// filter out versions like this: a6c5960, a7d5c522, d47be154
if !strings.Contains(agentVersion, ".") {
return false
}
normalizedVersion := normalizeVersion(agentVersion)
inputVer, err := version.NewVersion(normalizedVersion)
if err != nil {
return false
}
return inputVer.GreaterThanOrEqual(minVersion)
}
func normalizeVersion(version string) string {
// Remove prefixes like 'v' or 'a'
if len(version) > 0 && (version[0] == 'v' || version[0] == 'a') {
version = version[1:]
}
// Remove any suffixes like '-dirty', '-dev', '-SNAPSHOT', etc.
parts := strings.Split(version, "-")
return parts[0]
}

View File

@@ -0,0 +1,31 @@
package lazyconn
import "testing"
func TestIsSupported(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"development", true},
{"0.45.0", true},
{"v0.45.0", true},
{"0.45.1", true},
{"0.45.1-SNAPSHOT-559e6731", true},
{"v0.45.1-dev", true},
{"a7d5c522", false},
{"0.9.6", false},
{"0.9.6-SNAPSHOT", false},
{"0.9.6-SNAPSHOT-2033650", false},
{"meta_wt_version", false},
{"v0.31.1-dev", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
if got := IsSupported(tt.version); got != tt.want {
t.Errorf("IsSupported() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,14 @@
package lazyconn
import (
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}

View File

@@ -123,8 +123,14 @@ func (m *Manager) disableFlow() error {
m.logger.Close() m.logger.Close()
if m.receiverClient != nil { if m.receiverClient == nil {
return m.receiverClient.Close() return nil
}
err := m.receiverClient.Close()
m.receiverClient = nil
if err != nil {
return fmt.Errorf("close: %w", err)
} }
return nil return nil

View File

@@ -19,7 +19,7 @@ import (
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("open routing socket: %v", err)
} }
defer func() { defer func() {
err := unix.Close(fd) err := unix.Close(fd)

View File

@@ -13,7 +13,7 @@ import (
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx) routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create route monitor: %w", err) return fmt.Errorf("create route monitor: %w", err)
} }
defer func() { defer func() {
if err := routeMonitor.Stop(); err != nil { if err := routeMonitor.Stop(); err != nil {
@@ -38,35 +38,49 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
} }
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool { func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
intf := "<nil>" if intf := route.NextHop.Intf; intf != nil && isSoftInterface(intf.Name) {
if route.Interface != nil { log.Debugf("Network monitor: ignoring default route change for next hop with soft interface %s", route.NextHop)
intf = route.Interface.Name return false
if isSoftInterface(intf) { }
log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf)
return false // TODO: for the empty nexthop ip (on-link), determine the family differently
} nexthop := nexthopv4
if route.NextHop.IP.Is6() {
nexthop = nexthopv6
} }
switch route.Type { switch route.Type {
case systemops.RouteModified: case systemops.RouteModified, systemops.RouteAdded:
// TODO: get routing table to figure out if our route is affected for modified routes return handleRouteAddedOrModified(route, nexthop)
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
return true
case systemops.RouteAdded:
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
return true
}
case systemops.RouteDeleted: case systemops.RouteDeleted:
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { return handleRouteDeleted(route, nexthop)
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
return true
}
} }
return false return false
} }
func handleRouteAddedOrModified(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool {
// For added/modified routes, we care about different next hops
if !nexthop.Equal(route.NextHop) {
action := "changed"
if route.Type == systemops.RouteAdded {
action = "added"
}
log.Infof("Network monitor: default route %s: via %s", action, route.NextHop)
return true
}
return false
}
func handleRouteDeleted(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool {
// For deleted routes, we care about our tracked next hop being deleted
if nexthop.Equal(route.NextHop) {
log.Infof("Network monitor: default route removed: via %s", route.NextHop)
return true
}
return false
}
func isSoftInterface(name string) bool { func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo") return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
} }

View File

@@ -0,0 +1,404 @@
package networkmonitor
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func TestRouteChanged(t *testing.T) {
tests := []struct {
name string
route systemops.RouteUpdate
nexthopv4 systemops.Nexthop
nexthopv6 systemops.Nexthop
expected bool
}{
{
name: "soft interface should be ignored",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Name: "ISATAP-Interface", // isSoftInterface checks name
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.2"),
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: false,
},
{
name: "modified route with different v4 nexthop IP should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.2"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: true,
},
{
name: "modified route with same v4 nexthop (IP and Intf Index) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: false,
},
{
name: "added route with different v6 nexthop IP should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteAdded,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::2"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
expected: true,
},
{
name: "added route with same v6 nexthop (IP and Intf Index) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteAdded,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
expected: false,
},
{
name: "deleted route matching tracked v4 nexthop (IP and Intf Index) should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: true,
},
{
name: "deleted route not matching tracked v4 nexthop (different IP) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.3"), // Different IP
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: false,
},
{
name: "modified v4 route with same IP, different Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "modified v4 route with same IP, one Intf nil, other non-nil should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: nil, // Intf is nil
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"}, // Tracked Intf is not nil
},
expected: true,
},
{
name: "added v4 route with same IP, different Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteAdded,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "deleted v4 route with same IP, different Intf Index should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{ // This is the route being deleted
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv4: systemops.Nexthop{ // This is our tracked nexthop
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
expected: false, // Because nexthopv4.Equal(route.NextHop) will be false
},
{
name: "modified v6 route with different IP, same Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::3"), // Different IP
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "modified v6 route with same IP, different Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "modified v6 route with same IP, same Intf Index should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: false,
},
{
name: "deleted v6 route matching tracked nexthop (IP and Intf Index) should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "deleted v6 route not matching tracked nexthop (different IP) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::3"), // Different IP
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: false,
},
{
name: "deleted v6 route not matching tracked nexthop (same IP, different Intf Index) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{ // This is the route being deleted
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{ // This is our tracked nexthop
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
expected: false,
},
{
name: "unknown route type should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteUpdateType(99), // Unknown type
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.2"), // Different from route.NextHop
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := routeChanged(tt.route, tt.nexthopv4, tt.nexthopv6)
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsSoftInterface(t *testing.T) {
tests := []struct {
name string
ifname string
expected bool
}{
{
name: "ISATAP interface should be detected",
ifname: "ISATAP tunnel adapter",
expected: true,
},
{
name: "lowercase soft interface should be detected",
ifname: "isatap.{14A5CF17-CA72-43EC-B4EA-B4B093641B7D}",
expected: true,
},
{
name: "Teredo interface should be detected",
ifname: "Teredo Tunneling Pseudo-Interface",
expected: true,
},
{
name: "regular interface should not be detected as soft",
ifname: "eth0",
expected: false,
},
{
name: "another regular interface should not be detected as soft",
ifname: "wlan0",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isSoftInterface(tt.ifname)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -118,9 +118,12 @@ func (nw *NetworkMonitor) Stop() {
} }
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) { func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
defer close(event)
for { for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil { if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
close(event) if !errors.Is(err, context.Canceled) {
log.Errorf("Network monitor: failed to check for changes: %v", err)
}
return return
} }
// prevent blocking // prevent blocking

View File

@@ -17,8 +17,12 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -26,32 +30,20 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case connPriorityNone:
return "None"
case connPriorityRelay:
return "PriorityRelay"
case connPriorityICETurn:
return "PriorityICETurn"
case connPriorityICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}
const ( const (
defaultWgKeepAlive = 25 * time.Second defaultWgKeepAlive = 25 * time.Second
connPriorityNone ConnPriority = 0
connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 2
connPriorityICEP2P ConnPriority = 3
) )
type ServiceDependencies struct {
StatusRecorder *Status
Signaler *Signaler
IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher
}
type WgConfig struct { type WgConfig struct {
WgListenPort int WgListenPort int
RemoteKey string RemoteKey string
@@ -76,6 +68,8 @@ type ConnConfig struct {
// LocalKey is a public key of a local peer // LocalKey is a public key of a local peer
LocalKey string LocalKey string
AgentVersion string
Timeout time.Duration Timeout time.Duration
WgConfig WgConfig WgConfig WgConfig
@@ -89,22 +83,23 @@ type ConnConfig struct {
} }
type Conn struct { type Conn struct {
log *log.Entry Log *log.Entry
mu sync.Mutex mu sync.Mutex
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
config ConnConfig config ConnConfig
statusRecorder *Status statusRecorder *Status
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager relayManager *relayClient.Manager
handshaker *Handshaker srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string) onDisconnected func(remotePeer string)
statusRelay *AtomicConnStatus statusRelay *worker.AtomicWorkerStatus
statusICE *AtomicConnStatus statusICE *worker.AtomicWorkerStatus
currentConnPriority ConnPriority currentConnPriority conntype.ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection opened bool // this flag is used to prevent close in case of not opened connection
workerICE *WorkerICE workerICE *WorkerICE
@@ -120,9 +115,12 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup semaphore *semaphoregroup.SemaphoreGroup
peerConnDispatcher *dispatcher.ConnectionDispatcher
wg sync.WaitGroup
// debug purpose // debug purpose
dumpState *stateDump dumpState *stateDump
@@ -130,91 +128,101 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
if len(config.WgConfig.AllowedIps) == 0 { if len(config.WgConfig.AllowedIps) == 0 {
return nil, fmt.Errorf("allowed IPs is empty") return nil, fmt.Errorf("allowed IPs is empty")
} }
ctx, ctxCancel := context.WithCancel(engineCtx)
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
var conn = &Conn{ var conn = &Conn{
log: connLog, Log: connLog,
ctx: ctx, config: config,
ctxCancel: ctxCancel, statusRecorder: services.StatusRecorder,
config: config, signaler: services.Signaler,
statusRecorder: statusRecorder, iFaceDiscover: services.IFaceDiscover,
signaler: signaler, relayManager: services.RelayManager,
relayManager: relayManager, srWatcher: services.SrWatcher,
statusRelay: NewAtomicConnStatus(), semaphore: services.Semaphore,
statusICE: NewAtomicConnStatus(), peerConnDispatcher: services.PeerConnDispatcher,
semaphore: semaphore, statusRelay: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, statusRecorder), statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
} }
ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen()
go conn.dumpState.Start(ctx)
return conn, nil return conn, nil
} }
// Open opens connection to the remote peer // Open opens connection to the remote peer
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used. // be used.
func (conn *Conn) Open() { func (conn *Conn) Open(engineCtx context.Context) error {
conn.semaphore.Add(conn.ctx) conn.semaphore.Add(engineCtx)
conn.log.Debugf("open connection to peer")
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.opened = true
if conn.opened {
conn.semaphore.Done(engineCtx)
return nil
}
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.handshaker.Listen(conn.ctx)
}()
go conn.dumpState.Start(conn.ctx)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
ConnStatus: StatusDisconnected, ConnStatus: StatusConnecting,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
err := conn.statusRecorder.UpdatePeerState(peerState) if err := conn.statusRecorder.UpdatePeerState(peerState); err != nil {
if err != nil { conn.Log.Warnf("error while updating the state err: %v", err)
conn.log.Warnf("error while updating the state err: %v", err)
} }
go conn.startHandshakeAndReconnect(conn.ctx) conn.wg.Add(1)
} go func() {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx)
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { conn.dumpState.SendOffer()
defer conn.semaphore.Done(conn.ctx) if err := conn.handshaker.sendOffer(); err != nil {
conn.waitInitialRandomSleepTime(ctx) conn.Log.Errorf("failed to send initial offer: %v", err)
}
conn.dumpState.SendOffer() conn.wg.Add(1)
err := conn.handshaker.sendOffer() go func() {
if err != nil { conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.log.Errorf("failed to send initial offer: %v", err) conn.wg.Done()
} }()
}()
go conn.guard.Start(ctx) conn.opened = true
go conn.listenGuardEvent(ctx) return nil
} }
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
@@ -223,14 +231,14 @@ func (conn *Conn) Close() {
defer conn.wgWatcherWg.Wait() defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.log.Infof("close peer connection")
conn.ctxCancel()
if !conn.opened { if !conn.opened {
conn.log.Debugf("ignore close connection to peer") conn.Log.Debugf("ignore close connection to peer")
return return
} }
conn.Log.Infof("close peer connection")
conn.ctxCancel()
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
conn.workerRelay.CloseConn() conn.workerRelay.CloseConn()
conn.workerICE.Close() conn.workerICE.Close()
@@ -238,7 +246,7 @@ func (conn *Conn) Close() {
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn() err := conn.wgProxyRelay.CloseConn()
if err != nil { if err != nil {
conn.log.Errorf("failed to close wg proxy for relay: %v", err) conn.Log.Errorf("failed to close wg proxy for relay: %v", err)
} }
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
@@ -246,13 +254,13 @@ func (conn *Conn) Close() {
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
err := conn.wgProxyICE.CloseConn() err := conn.wgProxyICE.CloseConn()
if err != nil { if err != nil {
conn.log.Errorf("failed to close wg proxy for ice: %v", err) conn.Log.Errorf("failed to close wg proxy for ice: %v", err)
} }
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
if err := conn.removeWgPeer(); err != nil { if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID() conn.freeUpConnID()
@@ -262,14 +270,16 @@ func (conn *Conn) Close() {
} }
conn.setStatusToDisconnected() conn.setStatusToDisconnected()
conn.log.Infof("peer connection has been closed") conn.opened = false
conn.wg.Wait()
conn.Log.Infof("peer connection closed")
} }
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
conn.dumpState.RemoteAnswer() conn.dumpState.RemoteAnswer()
conn.log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) return conn.handshaker.OnRemoteAnswer(answer)
} }
@@ -298,7 +308,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
conn.dumpState.RemoteOffer() conn.dumpState.RemoteOffer()
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) return conn.handshaker.OnRemoteOffer(offer)
} }
@@ -307,19 +317,24 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig return conn.config.WgConfig
} }
// Status returns current status of the Conn // IsConnected unit tests only
func (conn *Conn) Status() ConnStatus { // refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
return conn.evalStatus() return conn.currentConnPriority != conntype.None
} }
func (conn *Conn) GetKey() string { func (conn *Conn) GetKey() string {
return conn.config.Key return conn.config.Key
} }
func (conn *Conn) ConnID() id.ConnID {
return id.ConnID(conn)
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -327,21 +342,21 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return return
} }
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) { if remoteConnNil(conn.Log, iceConnInfo.RemoteConn) {
conn.log.Errorf("remote ICE connection is nil") conn.Log.Errorf("remote ICE connection is nil")
return return
} }
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade // this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
// todo consider to remove this check // todo consider to remove this check
if conn.currentConnPriority > priority { if conn.currentConnPriority > priority {
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority) conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
conn.statusICE.Set(StatusConnected) conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
return return
} }
conn.log.Infof("set ICE to active connection") conn.Log.Infof("set ICE to active connection")
conn.dumpState.P2PConnected() conn.dumpState.P2PConnected()
var ( var (
@@ -353,7 +368,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return return
} }
ep = wgProxy.EndpointAddr() ep = wgProxy.EndpointAddr()
@@ -369,7 +384,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err) conn.Log.Errorf("Before add peer hook failed: %v", err)
} }
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
@@ -388,10 +403,16 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
oldState := conn.currentConnPriority
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected) conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
if oldState == conntype.None {
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
} }
func (conn *Conn) onICEStateDisconnected() { func (conn *Conn) onICEStateDisconnected() {
@@ -402,22 +423,22 @@ func (conn *Conn) onICEStateDisconnected() {
return return
} }
conn.log.Tracef("ICE connection state changed to disconnected") conn.Log.Tracef("ICE connection state changed to disconnected")
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil { if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
} }
} }
// switch back to relay connection // switch back to relay connection
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.log.Infof("ICE disconnected, set Relay to active connection") conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay() conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.Log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.wgWatcherWg.Add(1) conn.wgWatcherWg.Add(1)
@@ -425,17 +446,18 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done() defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
}() }()
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = conntype.Relay
} else { } else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = connPriorityNone conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
changed := conn.statusICE.Get() != StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
if changed { if changed {
conn.guard.SetICEConnDisconnected() conn.guard.SetICEConnDisconnected()
} }
conn.statusICE.Set(StatusDisconnected) conn.statusICE.SetDisconnected()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -446,7 +468,7 @@ func (conn *Conn) onICEStateDisconnected() {
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState) err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err) conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
} }
} }
@@ -456,41 +478,41 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil { if err := rci.relayedConn.Close(); err != nil {
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) conn.Log.Warnf("failed to close unnecessary relayed connection: %v", err)
} }
return return
} }
conn.dumpState.RelayConnected() conn.dumpState.RelayConnected()
conn.log.Debugf("Relay connection has been established, setup the WireGuard") conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy, err := conn.newProxy(rci.relayedConn) wgProxy, err := conn.newProxy(rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.isICEActive() { if conn.isICEActive() {
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected) conn.statusRelay.SetConnected()
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err) conn.Log.Errorf("Before add peer hook failed: %v", err)
} }
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err) conn.Log.Warnf("Failed to close relay connection: %v", err)
} }
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
@@ -502,12 +524,13 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround() wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = conntype.Relay
conn.statusRelay.Set(StatusConnected) conn.statusRelay.SetConnected()
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay") conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
} }
func (conn *Conn) onRelayDisconnected() { func (conn *Conn) onRelayDisconnected() {
@@ -518,14 +541,15 @@ func (conn *Conn) onRelayDisconnected() {
return return
} }
conn.log.Infof("relay connection is disconnected") conn.Log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay { if conn.currentConnPriority == conntype.Relay {
conn.log.Infof("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil { if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.currentConnPriority = connPriorityNone conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@@ -533,11 +557,11 @@ func (conn *Conn) onRelayDisconnected() {
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != worker.StatusDisconnected
if changed { if changed {
conn.guard.SetRelayedConnDisconnected() conn.guard.SetRelayedConnDisconnected()
} }
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.SetDisconnected()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -546,22 +570,15 @@ func (conn *Conn) onRelayDisconnected() {
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
} }
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) conn.Log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
} }
} }
func (conn *Conn) listenGuardEvent(ctx context.Context) { func (conn *Conn) onGuardEvent() {
for { conn.Log.Debugf("send offer to peer")
select { conn.dumpState.SendOffer()
case <-conn.guard.Reconnect: if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Infof("send offer to peer") conn.Log.Errorf("failed to send offer: %v", err)
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err)
}
case <-ctx.Done():
return
}
} }
} }
@@ -588,7 +605,7 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
err := conn.statusRecorder.UpdatePeerRelayedState(peerState) err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to save peer's Relay state, got error: %v", err) conn.Log.Warnf("unable to save peer's Relay state, got error: %v", err)
} }
} }
@@ -607,17 +624,18 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
err := conn.statusRecorder.UpdatePeerICEState(peerState) err := conn.statusRecorder.UpdatePeerICEState(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to save peer's ICE state, got error: %v", err) conn.Log.Warnf("unable to save peer's ICE state, got error: %v", err)
} }
} }
func (conn *Conn) setStatusToDisconnected() { func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.SetDisconnected()
conn.statusICE.Set(StatusDisconnected) conn.statusICE.SetDisconnected()
conn.currentConnPriority = conntype.None
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: StatusDisconnected, ConnStatus: StatusIdle,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
@@ -625,10 +643,10 @@ func (conn *Conn) setStatusToDisconnected() {
if err != nil { if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available. // pretty common error because by that time Engine can already remove the peer and status won't be available.
// todo rethink status updates // todo rethink status updates
conn.log.Debugf("error while updating peer's state, err: %v", err) conn.Log.Debugf("error while updating peer's state, err: %v", err)
} }
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil { if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err) conn.Log.Debugf("failed to reset wireguard stats for peer: %s", err)
} }
} }
@@ -656,27 +674,20 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
} }
func (conn *Conn) isRelayed() bool { func (conn *Conn) isRelayed() bool {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) { switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn:
return true
default:
return false return false
} }
if conn.currentConnPriority == connPriorityICEP2P {
return false
}
return true
} }
func (conn *Conn) evalStatus() ConnStatus { func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected { if conn.statusRelay.Get() == worker.StatusConnected || conn.statusICE.Get() == worker.StatusConnected {
return StatusConnected return StatusConnected
} }
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting { return StatusConnecting
return StatusConnecting
}
return StatusDisconnected
} }
func (conn *Conn) isConnectedOnAllWay() (connected bool) { func (conn *Conn) isConnectedOnAllWay() (connected bool) {
@@ -689,12 +700,12 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
} }
}() }()
if conn.statusICE.Get() == StatusDisconnected { if conn.statusICE.Get() == worker.StatusDisconnected {
return false return false
} }
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() != StatusConnected { if conn.statusRelay.Get() == worker.StatusDisconnected {
return false return false
} }
} }
@@ -716,7 +727,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" { if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil { if err := hook(conn.connIDRelay); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err) conn.Log.Errorf("After remove peer hook failed: %v", err)
} }
} }
conn.connIDRelay = "" conn.connIDRelay = ""
@@ -725,7 +736,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDICE != "" { if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil { if err := hook(conn.connIDICE); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err) conn.Log.Errorf("After remove peer hook failed: %v", err)
} }
} }
conn.connIDICE = "" conn.connIDICE = ""
@@ -733,7 +744,7 @@ func (conn *Conn) freeUpConnID() {
} }
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(), IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
Port: conn.config.WgConfig.WgListenPort, Port: conn.config.WgConfig.WgListenPort,
@@ -741,18 +752,18 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
wgProxy := conn.config.WgConfig.WgInterface.GetProxy() wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil { if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return nil, err return nil, err
} }
return wgProxy, nil return wgProxy, nil
} }
func (conn *Conn) isReadyToUpgrade() bool { func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
} }
func (conn *Conn) isICEActive() bool { func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
} }
func (conn *Conn) removeWgPeer() error { func (conn *Conn) removeWgPeer() error {
@@ -760,10 +771,10 @@ func (conn *Conn) removeWgPeer() error {
} }
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err) conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil { if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil { if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr) conn.Log.Warnf("Failed to close wg proxy: %v", ierr)
} }
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@@ -773,16 +784,16 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
func (conn *Conn) logTraceConnState() { func (conn *Conn) logTraceConnState() {
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) conn.Log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} else { } else {
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE) conn.Log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
} }
} }
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
} }
} }
conn.wgProxyRelay = proxy conn.wgProxyRelay = proxy
@@ -793,6 +804,10 @@ func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr() return conn.config.WgConfig.AllowedIps[0].Addr()
} }
func (conn *Conn) AgentVersionString() string {
return conn.config.AgentVersion
}
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key { func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
if conn.config.RosenpassConfig.PubKey == nil { if conn.config.RosenpassConfig.PubKey == nil {
return conn.config.WgConfig.PreSharedKey return conn.config.WgConfig.PreSharedKey
@@ -804,7 +819,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
determKey, err := conn.rosenpassDetermKey() determKey, err := conn.rosenpassDetermKey()
if err != nil { if err != nil {
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err) conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey return conn.config.WgConfig.PreSharedKey
} }

View File

@@ -1,58 +1,29 @@
package peer package peer
import ( import (
"sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
// StatusConnected indicate the peer is in connected state // StatusIdle indicate the peer is in disconnected state
StatusConnected ConnStatus = iota StatusIdle ConnStatus = iota
// StatusConnecting indicate the peer is in connecting state // StatusConnecting indicate the peer is in connecting state
StatusConnecting StatusConnecting
// StatusDisconnected indicate the peer is in disconnected state // StatusConnected indicate the peer is in connected state
StatusDisconnected StatusConnected
) )
// ConnStatus describe the status of a peer's connection // ConnStatus describe the status of a peer's connection
type ConnStatus int32 type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string { func (s ConnStatus) String() string {
switch s { switch s {
case StatusConnecting: case StatusConnecting:
return "Connecting" return "Connecting"
case StatusConnected: case StatusConnected:
return "Connected" return "Connected"
case StatusDisconnected: case StatusIdle:
return "Disconnected" return "Idle"
default: default:
log.Errorf("unknown status: %d", s) log.Errorf("unknown status: %d", s)
return "INVALID_PEER_CONNECTION_STATUS" return "INVALID_PEER_CONNECTION_STATUS"

View File

@@ -14,7 +14,7 @@ func TestConnStatus_String(t *testing.T) {
want string want string
}{ }{
{"StatusConnected", StatusConnected, "Connected"}, {"StatusConnected", StatusConnected, "Connected"},
{"StatusDisconnected", StatusDisconnected, "Disconnected"}, {"StatusIdle", StatusIdle, "Idle"},
{"StatusConnecting", StatusConnecting, "Connecting"}, {"StatusConnecting", StatusConnecting, "Connecting"},
} }
@@ -24,5 +24,4 @@ func TestConnStatus_String(t *testing.T) {
assert.Equal(t, got, table.want, "they should be equal") assert.Equal(t, got, table.want, "they should be equal")
}) })
} }
} }

View File

@@ -1,7 +1,6 @@
package peer package peer
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@@ -11,6 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -18,6 +18,8 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
var testDispatcher = dispatcher.NewConnectionDispatcher()
var connConf = ConnConfig{ var connConf = ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
@@ -48,7 +50,13 @@ func TestNewConn_interfaceFilter(t *testing.T) {
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
sd := ServiceDependencies{
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@@ -60,7 +68,13 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@@ -94,7 +108,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@@ -125,43 +145,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
tables := []struct {
name string
statusIce ConnStatus
statusRelay ConnStatus
want ConnStatus
}{
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
func TestConn_presharedKey(t *testing.T) { func TestConn_presharedKey(t *testing.T) {
conn1 := Conn{ conn1 := Conn{

View File

@@ -0,0 +1,29 @@
package conntype
import (
"fmt"
)
const (
None ConnPriority = 0
Relay ConnPriority = 1
ICETurn ConnPriority = 2
ICEP2P ConnPriority = 3
)
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case None:
return "None"
case Relay:
return "PriorityRelay"
case ICETurn:
return "PriorityICETurn"
case ICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}

View File

@@ -0,0 +1,52 @@
package dispatcher
import (
"sync"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type ConnectionListener struct {
OnConnected func(peerID id.ConnID)
OnDisconnected func(peerID id.ConnID)
}
type ConnectionDispatcher struct {
listeners map[*ConnectionListener]struct{}
mu sync.Mutex
}
func NewConnectionDispatcher() *ConnectionDispatcher {
return &ConnectionDispatcher{
listeners: make(map[*ConnectionListener]struct{}),
}
}
func (e *ConnectionDispatcher) AddListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
e.listeners[listener] = struct{}{}
}
func (e *ConnectionDispatcher) RemoveListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
delete(e.listeners, listener)
}
func (e *ConnectionDispatcher) NotifyConnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnConnected(peerConnID)
}
}
func (e *ConnectionDispatcher) NotifyDisconnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnDisconnected(peerConnID)
}
}

View File

@@ -8,10 +8,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
reconnectMaxElapsedTime = 30 * time.Minute
)
type isConnectedFunc func() bool type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic. // Guard is responsible for the reconnection logic.
@@ -25,7 +21,6 @@ type isConnectedFunc func() bool
type Guard struct { type Guard struct {
Reconnect chan struct{} Reconnect chan struct{}
log *log.Entry log *log.Entry
isController bool
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay isConnectedFunc
timeout time.Duration timeout time.Duration
srWatcher *SRWatcher srWatcher *SRWatcher
@@ -33,11 +28,10 @@ type Guard struct {
iCEConnDisconnected chan struct{} iCEConnDisconnected chan struct{}
} }
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
Reconnect: make(chan struct{}, 1), Reconnect: make(chan struct{}, 1),
log: log, log: log,
isController: isController,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
timeout: timeout, timeout: timeout,
srWatcher: srWatcher, srWatcher: srWatcher,
@@ -46,12 +40,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
} }
} }
func (g *Guard) Start(ctx context.Context) { func (g *Guard) Start(ctx context.Context, eventCallback func()) {
if g.isController { g.reconnectLoopWithRetry(ctx, eventCallback)
g.reconnectLoopWithRetry(ctx)
} else {
g.listenForDisconnectEvents(ctx)
}
} }
func (g *Guard) SetRelayedConnDisconnected() { func (g *Guard) SetRelayedConnDisconnected() {
@@ -68,9 +58,9 @@ func (g *Guard) SetICEConnDisconnected() {
} }
} }
// reconnectLoopWithRetry periodically check (max 30 min) the connection status. // reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported // Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
waitForInitialConnectionTry(ctx) waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
@@ -93,7 +83,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
} }
if !g.isConnectedOnAllWay() { if !g.isConnectedOnAllWay() {
g.triggerOfferSending() callback()
} }
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
@@ -121,39 +111,12 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
} }
} }
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
g.log.Infof("start listen for reconnect events...")
for {
select {
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending()
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending()
case <-srReconnectedChan:
g.triggerOfferSending()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1, RandomizationFactor: 0.1,
Multiplier: 2, Multiplier: 2,
MaxInterval: g.timeout, MaxInterval: g.timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
@@ -164,13 +127,6 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
return ticker return ticker
} }
func (g *Guard) triggerOfferSending() {
select {
case g.Reconnect <- struct{}{}:
default:
}
}
// Give chance to the peer to establish the initial connection. // Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer // With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) { func waitForInitialConnectionTry(ctx context.Context) {

View File

@@ -43,7 +43,6 @@ type OfferAnswer struct {
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
ctx context.Context
log *log.Entry log *log.Entry
config ConnConfig config ConnConfig
signaler *Signaler signaler *Signaler
@@ -57,9 +56,8 @@ type Handshaker struct {
remoteAnswerCh chan OfferAnswer remoteAnswerCh chan OfferAnswer
} }
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker { func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
return &Handshaker{ return &Handshaker{
ctx: ctx,
log: log, log: log,
config: config, config: config,
signaler: signaler, signaler: signaler,
@@ -74,10 +72,10 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
h.onNewOfferListeners = append(h.onNewOfferListeners, offer) h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
} }
func (h *Handshaker) Listen() { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
h.log.Info("wait for remote offer confirmation") h.log.Info("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation() remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx)
if err != nil { if err != nil {
var connectionClosedError *ConnectionClosedError var connectionClosedError *ConnectionClosedError
if errors.As(err, &connectionClosedError) { if errors.As(err, &connectionClosedError) {
@@ -127,7 +125,7 @@ func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
} }
} }
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) { func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed // received confirmation from the remote peer -> ready to proceed
@@ -137,7 +135,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
return &remoteOfferAnswer, nil return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil return &remoteOfferAnswer, nil
case <-h.ctx.Done(): case <-ctx.Done():
// closed externally // closed externally
return nil, NewConnectionClosedError(h.config.Key) return nil, NewConnectionClosedError(h.config.Key)
} }

View File

@@ -0,0 +1,5 @@
package id
import "unsafe"
type ConnID unsafe.Pointer

View File

@@ -15,7 +15,7 @@ import (
type WGIface interface { type WGIface interface {
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address Address() wgaddr.Address
} }

View File

@@ -135,14 +135,15 @@ type NSGroupState struct {
// FullStatus contains the full state held by the Status instance // FullStatus contains the full state held by the Status instance
type FullStatus struct { type FullStatus struct {
Peers []State Peers []State
ManagementState ManagementState ManagementState ManagementState
SignalState SignalState SignalState SignalState
LocalPeerState LocalPeerState LocalPeerState LocalPeerState
RosenpassState RosenpassState RosenpassState RosenpassState
Relays []relay.ProbeResult Relays []relay.ProbeResult
NSGroupStates []NSGroupState NSGroupStates []NSGroupState
NumOfForwardingRules int NumOfForwardingRules int
LazyConnectionEnabled bool
} }
// Status holds a state of peers, signal, management connections and relays // Status holds a state of peers, signal, management connections and relays
@@ -164,6 +165,7 @@ type Status struct {
rosenpassPermissive bool rosenpassPermissive bool
nsGroupStates []NSGroupState nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
lazyConnectionEnabled bool
// To reduce the number of notification invocation this bool will be true when need to call the notification // To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -219,7 +221,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
} }
// AddPeer adds peer to Daemon status map // AddPeer adds peer to Daemon status map
func (d *Status) AddPeer(peerPubKey string, fqdn string) error { func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -229,7 +231,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
} }
d.peers[peerPubKey] = State{ d.peers[peerPubKey] = State{
PubKey: peerPubKey, PubKey: peerPubKey,
ConnStatus: StatusDisconnected, IP: ip,
ConnStatus: StatusIdle,
FQDN: fqdn, FQDN: fqdn,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
@@ -511,9 +514,9 @@ func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
switch { switch {
case receivedConnStatus == StatusConnecting: case receivedConnStatus == StatusConnecting:
return true return true
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting: case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting:
return true return true
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected: case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
return curr.IP != "" return curr.IP != ""
default: default:
return false return false
@@ -689,6 +692,12 @@ func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
d.rosenpassEnabled = rosenpassEnabled d.rosenpassEnabled = rosenpassEnabled
} }
func (d *Status) UpdateLazyConnection(enabled bool) {
d.mux.Lock()
defer d.mux.Unlock()
d.lazyConnectionEnabled = enabled
}
// MarkSignalDisconnected sets SignalState to disconnected // MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) { func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Lock() d.mux.Lock()
@@ -761,6 +770,12 @@ func (d *Status) GetRosenpassState() RosenpassState {
} }
} }
func (d *Status) GetLazyConnection() bool {
d.mux.Lock()
defer d.mux.Unlock()
return d.lazyConnectionEnabled
}
func (d *Status) GetManagementState() ManagementState { func (d *Status) GetManagementState() ManagementState {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -872,12 +887,13 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
// GetFullStatus gets full status // GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus { func (d *Status) GetFullStatus() FullStatus {
fullStatus := FullStatus{ fullStatus := FullStatus{
ManagementState: d.GetManagementState(), ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(), SignalState: d.GetSignalState(),
Relays: d.GetRelayStates(), Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(), RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(), NSGroupStates: d.GetDNSStates(),
NumOfForwardingRules: len(d.ForwardingRules()), NumOfForwardingRules: len(d.ForwardingRules()),
LazyConnectionEnabled: d.GetLazyConnection(),
} }
d.mux.Lock() d.mux.Lock()

View File

@@ -10,22 +10,24 @@ import (
func TestAddPeer(t *testing.T) { func TestAddPeer(t *testing.T) {
key := "abc" key := "abc"
ip := "100.108.254.1"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
err := status.AddPeer(key, "abc.netbird") err := status.AddPeer(key, "abc.netbird", ip)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
_, exists := status.peers[key] _, exists := status.peers[key]
assert.True(t, exists, "value was found") assert.True(t, exists, "value was found")
err = status.AddPeer(key, "abc.netbird") err = status.AddPeer(key, "abc.netbird", ip)
assert.Error(t, err, "should return error on duplicate") assert.Error(t, err, "should return error on duplicate")
} }
func TestGetPeer(t *testing.T) { func TestGetPeer(t *testing.T) {
key := "abc" key := "abc"
ip := "100.108.254.1"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
err := status.AddPeer(key, "abc.netbird") err := status.AddPeer(key, "abc.netbird", ip)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
peerStatus, err := status.GetPeer(key) peerStatus, err := status.GetPeer(key)

View File

@@ -2,6 +2,7 @@ package peer
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"time" "time"
@@ -20,7 +21,7 @@ var (
) )
type WGInterfaceStater interface { type WGInterfaceStater interface {
GetStats(key string) (configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
} }
type WGWatcher struct { type WGWatcher struct {
@@ -146,9 +147,13 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
} }
func (w *WGWatcher) wgState() (time.Time, error) { func (w *WGWatcher) wgState() (time.Time, error) {
wgState, err := w.wgIfaceStater.GetStats(w.peerKey) wgStates, err := w.wgIfaceStater.GetStats()
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
} }
wgState, ok := wgStates[w.peerKey]
if !ok {
return time.Time{}, fmt.Errorf("peer %s not found in WireGuard endpoints", w.peerKey)
}
return wgState.LastHandshake, nil return wgState.LastHandshake, nil
} }

View File

@@ -11,26 +11,11 @@ import (
) )
type MocWgIface struct { type MocWgIface struct {
initial bool stop bool
lastHandshake time.Time
stop bool
} }
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) { func (m *MocWgIface) GetStats() (map[string]configurer.WGStats, error) {
if !m.initial { return map[string]configurer.WGStats{}, nil
m.initial = true
return configurer.WGStats{}, nil
}
if !m.stop {
m.lastHandshake = time.Now()
}
stats := configurer.WGStats{
LastHandshake: m.lastHandshake,
}
return stats, nil
} }
func (m *MocWgIface) disconnect() { func (m *MocWgIface) disconnect() {

View File

@@ -0,0 +1,55 @@
package worker
import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const (
StatusDisconnected Status = iota
StatusConnected
)
type Status int32
func (s Status) String() string {
switch s {
case StatusDisconnected:
return "Disconnected"
case StatusConnected:
return "Connected"
default:
log.Errorf("unknown status: %d", s)
return "unknown"
}
}
// AtomicWorkerStatus is a thread-safe wrapper for worker status
type AtomicWorkerStatus struct {
status atomic.Int32
}
func NewAtomicStatus() *AtomicWorkerStatus {
acs := &AtomicWorkerStatus{}
acs.SetDisconnected()
return acs
}
// Get returns the current connection status
func (acs *AtomicWorkerStatus) Get() Status {
return Status(acs.status.Load())
}
func (acs *AtomicWorkerStatus) SetConnected() {
acs.status.Store(int32(StatusConnected))
}
func (acs *AtomicWorkerStatus) SetDisconnected() {
acs.status.Store(int32(StatusDisconnected))
}
// String returns the string representation of the current status
func (acs *AtomicWorkerStatus) String() string {
return acs.Get().String()
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -397,10 +398,10 @@ func isRelayed(pair *ice.CandidatePair) bool {
return false return false
} }
func selectedPriority(pair *ice.CandidatePair) ConnPriority { func selectedPriority(pair *ice.CandidatePair) conntype.ConnPriority {
if isRelayed(pair) { if isRelayed(pair) {
return connPriorityICETurn return conntype.ICETurn
} else { } else {
return connPriorityICEP2P return conntype.ICEP2P
} }
} }

View File

@@ -1,6 +1,7 @@
package peerstore package peerstore
import ( import (
"context"
"net/netip" "net/netip"
"sync" "sync"
@@ -79,6 +80,32 @@ func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
return p, true return p, true
} }
func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
// this can be blocked because of the connect open limiter semaphore
if err := p.Open(ctx); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
func (s *Store) PeerConnClose(pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
p.Close()
}
func (s *Store) PeersPubKey() []string { func (s *Store) PeersPubKey() []string {
s.peerConnsMu.RLock() s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock() defer s.peerConnsMu.RUnlock()

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/client/common"
) )
// PKCEAuthorizationFlow represents PKCE Authorization Flow information // PKCEAuthorizationFlow represents PKCE Authorization Flow information
@@ -41,6 +42,8 @@ type PKCEAuthProviderConfig struct {
ClientCertPair *tls.Certificate ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login // DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
} }
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
@@ -100,6 +103,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(), UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert, ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(), DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
}, },
} }

View File

@@ -3,7 +3,6 @@ package iface
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -18,5 +17,4 @@ type wgIfaceBase interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
} }

View File

@@ -32,6 +32,10 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0) nets := make([]string, 0)
for _, r := range clientRoutes { for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() {
continue
}
nets = append(nets, r.Network.String()) nets = append(nets, r.Network.String())
} }
sort.Strings(nets) sort.Strings(nets)

View File

@@ -1,6 +1,7 @@
package systemops package systemops
import ( import (
"fmt"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
@@ -15,6 +16,20 @@ type Nexthop struct {
Intf *net.Interface Intf *net.Interface
} }
// Equal checks if two nexthops are equal.
func (n Nexthop) Equal(other Nexthop) bool {
return n.IP == other.IP && (n.Intf == nil && other.Intf == nil ||
n.Intf != nil && other.Intf != nil && n.Intf.Index == other.Intf.Index)
}
// String returns a string representation of the nexthop.
func (n Nexthop) String() string {
if n.Intf == nil {
return n.IP.String()
}
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
}
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct { type SysOps struct {

View File

@@ -33,8 +33,7 @@ type RouteUpdateType int
type RouteUpdate struct { type RouteUpdate struct {
Type RouteUpdateType Type RouteUpdateType
Destination netip.Prefix Destination netip.Prefix
NextHop netip.Addr NextHop Nexthop
Interface *net.Interface
} }
// RouteMonitor provides a way to monitor changes in the routing table. // RouteMonitor provides a way to monitor changes in the routing table.
@@ -231,15 +230,15 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
intf, err := net.InterfaceByIndex(idx) intf, err := net.InterfaceByIndex(idx)
if err != nil { if err != nil {
log.Warnf("failed to get interface name for index %d: %v", idx, err) log.Warnf("failed to get interface name for index %d: %v", idx, err)
update.Interface = &net.Interface{ update.NextHop.Intf = &net.Interface{
Index: idx, Index: idx,
} }
} else { } else {
update.Interface = intf update.NextHop.Intf = intf
} }
} }
log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.NextHop.Intf)
dest := parseIPPrefix(row.DestinationPrefix, idx) dest := parseIPPrefix(row.DestinationPrefix, idx)
if !dest.Addr().IsValid() { if !dest.Addr().IsValid() {
return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row) return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row)
@@ -262,7 +261,7 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
update.Type = updateType update.Type = updateType
update.Destination = dest update.Destination = dest
update.NextHop = nexthop update.NextHop.IP = nexthop
return update, nil return update, nil
} }

File diff suppressed because it is too large Load Diff

View File

@@ -94,7 +94,7 @@ message LoginRequest {
bytes customDNSAddress = 7; bytes customDNSAddress = 7;
bool isLinuxDesktopClient = 8; bool isUnixDesktopClient = 8;
string hostname = 9; string hostname = 9;
@@ -134,6 +134,7 @@ message LoginRequest {
// omits initialized empty slices due to omitempty tags // omits initialized empty slices due to omitempty tags
bool cleanDNSLabels = 27; bool cleanDNSLabels = 27;
optional bool lazyConnectionEnabled = 28;
} }
message LoginResponse { message LoginResponse {
@@ -274,6 +275,8 @@ message FullStatus {
int32 NumberOfForwardingRules = 8; int32 NumberOfForwardingRules = 8;
repeated SystemEvent events = 7; repeated SystemEvent events = 7;
bool lazyConnectionEnabled = 9;
} }
// Networks // Networks

View File

@@ -51,14 +51,16 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
} }
if req.GetUploadURL() == "" { if req.GetUploadURL() == "" {
return &proto.DebugBundleResponse{Path: path}, nil return &proto.DebugBundleResponse{Path: path}, nil
} }
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path) key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
if err != nil { if err != nil {
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
} }
log.Infof("debug bundle uploaded to %s with key %s", req.GetUploadURL(), key)
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
} }

View File

@@ -139,6 +139,7 @@ func (s *Server) Start() error {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String()) s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
if s.sessionWatcher == nil { if s.sessionWatcher == nil {
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder) s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
@@ -417,6 +418,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.latestConfigInput.DisableNotifications = msg.DisableNotifications s.latestConfigInput.DisableNotifications = msg.DisableNotifications
} }
if msg.LazyConnectionEnabled != nil {
inputConfig.LazyConnectionEnabled = msg.LazyConnectionEnabled
s.latestConfigInput.LazyConnectionEnabled = msg.LazyConnectionEnabled
}
s.mutex.Unlock() s.mutex.Unlock()
if msg.OptionalPreSharedKey != nil { if msg.OptionalPreSharedKey != nil {
@@ -446,7 +452,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
if msg.SetupKey == "" { if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsLinuxDesktopClient) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
if err != nil { if err != nil {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
return nil, err return nil, err
@@ -804,6 +810,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules) pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers { for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{ pbPeerState := &proto.PeerState{

View File

@@ -97,6 +97,7 @@ type OutputOverview struct {
NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"` NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
Events []SystemEventOutput `json:"events" yaml:"events"` Events []SystemEventOutput `json:"events" yaml:"events"`
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
} }
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview { func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
@@ -136,6 +137,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()), NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
Events: mapEvents(pbFullStatus.GetEvents()), Events: mapEvents(pbFullStatus.GetEvents()),
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
} }
if anon { if anon {
@@ -206,7 +208,7 @@ func mapPeers(
transferSent := int64(0) transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) { if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
continue continue
} }
if isPeerConnected { if isPeerConnected {
@@ -384,6 +386,11 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
} }
} }
lazyConnectionEnabledStatus := "false"
if overview.LazyConnectionEnabled {
lazyConnectionEnabledStatus = "true"
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS goos := runtime.GOOS
@@ -405,6 +412,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"NetBird IP: %s\n"+ "NetBird IP: %s\n"+
"Interface type: %s\n"+ "Interface type: %s\n"+
"Quantum resistance: %s\n"+ "Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"Networks: %s\n"+ "Networks: %s\n"+
"Forwarding rules: %d\n"+ "Forwarding rules: %d\n"+
"Peers count: %s\n", "Peers count: %s\n",
@@ -419,6 +427,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
interfaceIP, interfaceIP,
interfaceTypeString, interfaceTypeString,
rosenpassEnabledStatus, rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
networks, networks,
overview.NumberOfForwardingRules, overview.NumberOfForwardingRules,
peersCountString, peersCountString,
@@ -533,23 +542,13 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
return peersString return peersString
} }
func skipDetailByFilters( func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool {
peerState *proto.PeerState,
isConnected bool,
statusFilter string,
prefixNamesFilter []string,
prefixNamesFilterMap map[string]struct{},
ipsFilter map[string]struct{},
) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := true nameEval := true
if statusFilter != "" { if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter) if !strings.EqualFold(peerStatus, statusFilter) {
if lowerStatusFilter == "disconnected" && isConnected {
statusEval = true
} else if lowerStatusFilter == "connected" && !isConnected {
statusEval = true statusEval = true
} }
} }

View File

@@ -383,7 +383,8 @@ func TestParsingToJSON(t *testing.T) {
"error": "timeout" "error": "timeout"
} }
], ],
"events": [] "events": [],
"lazyConnectionEnabled": false
}` }`
// @formatter:on // @formatter:on
@@ -484,6 +485,7 @@ dnsServers:
enabled: false enabled: false
error: timeout error: timeout
events: [] events: []
lazyConnectionEnabled: false
` `
assert.Equal(t, expectedYAML, yaml) assert.Equal(t, expectedYAML, yaml)
@@ -548,6 +550,7 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16 NetBird IP: 192.168.178.100/16
Interface type: Kernel Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Lazy connection: false
Networks: 10.10.0.0/24 Networks: 10.10.0.0/24
Forwarding rules: 0 Forwarding rules: 0
Peers count: 2/2 Connected Peers count: 2/2 Connected
@@ -570,6 +573,7 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16 NetBird IP: 192.168.178.100/16
Interface type: Kernel Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Lazy connection: false
Networks: 10.10.0.0/24 Networks: 10.10.0.0/24
Forwarding rules: 0 Forwarding rules: 0
Peers count: 2/2 Connected Peers count: 2/2 Connected

View File

@@ -51,14 +51,19 @@ const (
) )
func main() { func main() {
daemonAddr, showSettings, showNetworks, errorMsg, saveLogsInFile := parseFlags() daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags()
// Initialize file logging if needed. // Initialize file logging if needed.
var logFile string
if saveLogsInFile { if saveLogsInFile {
if err := initLogFile(); err != nil { file, err := initLogFile()
if err != nil {
log.Errorf("error while initializing log: %v", err) log.Errorf("error while initializing log: %v", err)
return return
} }
logFile = file
} else {
_ = util.InitLog("trace", "console")
} }
// Create the Fyne application. // Create the Fyne application.
@@ -72,13 +77,13 @@ func main() {
} }
// Create the service client (this also builds the settings or networks UI if requested). // Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(daemonAddr, a, showSettings, showNetworks) client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug)
// Watch for theme/settings changes to update the icon. // Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client) go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set. // Run in window mode if any UI flag was set.
if showSettings || showNetworks { if showSettings || showNetworks || showDebug {
a.Run() a.Run()
return return
} }
@@ -99,7 +104,7 @@ func main() {
} }
// parseFlags reads and returns all needed command-line flags. // parseFlags reads and returns all needed command-line flags.
func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg string, saveLogsInFile bool) { func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) {
defaultDaemonAddr := "unix:///var/run/netbird.sock" defaultDaemonAddr := "unix:///var/run/netbird.sock"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
defaultDaemonAddr = "tcp://127.0.0.1:41731" defaultDaemonAddr = "tcp://127.0.0.1:41731"
@@ -107,25 +112,17 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg
flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
flag.BoolVar(&showSettings, "settings", false, "run settings window") flag.BoolVar(&showSettings, "settings", false, "run settings window")
flag.BoolVar(&showNetworks, "networks", false, "run networks window") flag.BoolVar(&showNetworks, "networks", false, "run networks window")
flag.BoolVar(&showDebug, "debug", false, "run debug window")
flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window") flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
flag.Parse() flag.Parse()
return return
} }
// initLogFile initializes logging into a file. // initLogFile initializes logging into a file.
func initLogFile() error { func initLogFile() (string, error) {
tmpDir := "/tmp" logFile := path.Join(os.TempDir(), fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
if runtime.GOOS == "windows" { return logFile, util.InitLog("trace", logFile)
tmpDir = os.TempDir()
}
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
return util.InitLog("trace", logFile)
} }
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon. // watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
@@ -168,9 +165,10 @@ var iconConnectingMacOS []byte
var iconErrorMacOS []byte var iconErrorMacOS []byte
type serviceClient struct { type serviceClient struct {
ctx context.Context ctx context.Context
addr string cancel context.CancelFunc
conn proto.DaemonServiceClient addr string
conn proto.DaemonServiceClient
icAbout []byte icAbout []byte
icConnected []byte icConnected []byte
@@ -195,6 +193,7 @@ type serviceClient struct {
mAllowSSH *systray.MenuItem mAllowSSH *systray.MenuItem
mAutoConnect *systray.MenuItem mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
mNotifications *systray.MenuItem mNotifications *systray.MenuItem
mAdvancedSettings *systray.MenuItem mAdvancedSettings *systray.MenuItem
mCreateDebugBundle *systray.MenuItem mCreateDebugBundle *systray.MenuItem
@@ -231,13 +230,14 @@ type serviceClient struct {
daemonVersion string daemonVersion string
updateIndicationLock sync.Mutex updateIndicationLock sync.Mutex
isUpdateIconActive bool isUpdateIconActive bool
showRoutes bool showNetworks bool
wRoutes fyne.Window wNetworks fyne.Window
eventManager *event.Manager eventManager *event.Manager
exitNodeMu sync.Mutex exitNodeMu sync.Mutex
mExitNodeItems []menuHandler mExitNodeItems []menuHandler
logFile string
} }
type menuHandler struct { type menuHandler struct {
@@ -248,25 +248,30 @@ type menuHandler struct {
// newServiceClient instance constructor // newServiceClient instance constructor
// //
// This constructor also builds the UI elements for the settings window. // This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient { func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient {
ctx, cancel := context.WithCancel(context.Background())
s := &serviceClient{ s := &serviceClient{
ctx: context.Background(), ctx: ctx,
cancel: cancel,
addr: addr, addr: addr,
app: a, app: a,
logFile: logFile,
sendNotification: false, sendNotification: false,
showAdvancedSettings: showSettings, showAdvancedSettings: showSettings,
showRoutes: showRoutes, showNetworks: showNetworks,
update: version.NewUpdate(), update: version.NewUpdate(),
} }
s.setNewIcons() s.setNewIcons()
if showSettings { switch {
case showSettings:
s.showSettingsUI() s.showSettingsUI()
return s case showNetworks:
} else if showRoutes {
s.showNetworksUI() s.showNetworksUI()
case showDebug:
s.showDebugUI()
} }
return s return s
@@ -313,6 +318,8 @@ func (s *serviceClient) updateIcon() {
func (s *serviceClient) showSettingsUI() { func (s *serviceClient) showSettingsUI() {
// add settings window UI elements. // add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings") s.wSettings = s.app.NewWindow("NetBird Settings")
s.wSettings.SetOnClosed(s.cancel)
s.iMngURL = widget.NewEntry() s.iMngURL = widget.NewEntry()
s.iAdminURL = widget.NewEntry() s.iAdminURL = widget.NewEntry()
s.iConfigFile = widget.NewEntry() s.iConfigFile = widget.NewEntry()
@@ -378,12 +385,12 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
s.adminURL = iAdminURL s.adminURL = iAdminURL
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
ManagementUrl: iMngURL, ManagementUrl: iMngURL,
AdminURL: iAdminURL, AdminURL: iAdminURL,
IsLinuxDesktopClient: runtime.GOOS == "linux", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
RosenpassPermissive: &s.sRosenpassPermissive.Checked, RosenpassPermissive: &s.sRosenpassPermissive.Checked,
InterfaceName: &s.iInterfaceName.Text, InterfaceName: &s.iInterfaceName.Text,
WireguardPort: &port, WireguardPort: &port,
} }
if s.iPreSharedKey.Text != censoredPreSharedKey { if s.iPreSharedKey.Text != censoredPreSharedKey {
@@ -410,7 +417,7 @@ func (s *serviceClient) login() error {
} }
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
IsLinuxDesktopClient: runtime.GOOS == "linux", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
}) })
if err != nil { if err != nil {
log.Errorf("login to management URL with: %v", err) log.Errorf("login to management URL with: %v", err)
@@ -626,6 +633,7 @@ func (s *serviceClient) onTrayReady() {
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false) s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false) s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false) s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable lazy connection", lazyConnMenuDescr, false)
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false) s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr) s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr)
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr) s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
@@ -685,111 +693,120 @@ func (s *serviceClient) onTrayReady() {
go s.eventManager.Start(s.ctx) go s.eventManager.Start(s.ctx)
go func() { go s.listenEvents()
for { }
select {
case <-s.mUp.ClickedCh:
s.mUp.Disable()
go func() {
defer s.mUp.Enable()
err := s.menuUpClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mDown.ClickedCh:
s.mDown.Disable()
go func() {
defer s.mDown.Enable()
err := s.menuDownClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mAllowSSH.ClickedCh:
if s.mAllowSSH.Checked() {
s.mAllowSSH.Uncheck()
} else {
s.mAllowSSH.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAutoConnect.ClickedCh:
if s.mAutoConnect.Checked() {
s.mAutoConnect.Uncheck()
} else {
s.mAutoConnect.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mEnableRosenpass.ClickedCh:
if s.mEnableRosenpass.Checked() {
s.mEnableRosenpass.Uncheck()
} else {
s.mEnableRosenpass.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAdvancedSettings.ClickedCh:
s.mAdvancedSettings.Disable()
go func() {
defer s.mAdvancedSettings.Enable()
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
}()
case <-s.mCreateDebugBundle.ClickedCh:
go func() {
if err := s.createAndOpenDebugBundle(); err != nil {
log.Errorf("Failed to create debug bundle: %v", err)
s.app.SendNotification(fyne.NewNotification("Error", "Failed to create debug bundle"))
}
}()
case <-s.mQuit.ClickedCh:
systray.Quit()
return
case <-s.mGitHub.ClickedCh:
err := openURL("https://github.com/netbirdio/netbird")
if err != nil {
log.Errorf("%s", err)
}
case <-s.mUpdate.ClickedCh:
err := openURL(version.DownloadUrl())
if err != nil {
log.Errorf("%s", err)
}
case <-s.mNetworks.ClickedCh:
s.mNetworks.Disable()
go func() {
defer s.mNetworks.Enable()
s.runSelfCommand("networks", "true")
}()
case <-s.mNotifications.ClickedCh:
if s.mNotifications.Checked() {
s.mNotifications.Uncheck()
} else {
s.mNotifications.Check()
}
if s.eventManager != nil {
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
}
func (s *serviceClient) listenEvents() {
for {
select {
case <-s.mUp.ClickedCh:
s.mUp.Disable()
go func() {
defer s.mUp.Enable()
err := s.menuUpClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mDown.ClickedCh:
s.mDown.Disable()
go func() {
defer s.mDown.Enable()
err := s.menuDownClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mAllowSSH.ClickedCh:
if s.mAllowSSH.Checked() {
s.mAllowSSH.Uncheck()
} else {
s.mAllowSSH.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAutoConnect.ClickedCh:
if s.mAutoConnect.Checked() {
s.mAutoConnect.Uncheck()
} else {
s.mAutoConnect.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mEnableRosenpass.ClickedCh:
if s.mEnableRosenpass.Checked() {
s.mEnableRosenpass.Uncheck()
} else {
s.mEnableRosenpass.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mLazyConnEnabled.ClickedCh:
if s.mLazyConnEnabled.Checked() {
s.mLazyConnEnabled.Uncheck()
} else {
s.mLazyConnEnabled.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAdvancedSettings.ClickedCh:
s.mAdvancedSettings.Disable()
go func() {
defer s.mAdvancedSettings.Enable()
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
}()
case <-s.mCreateDebugBundle.ClickedCh:
s.mCreateDebugBundle.Disable()
go func() {
defer s.mCreateDebugBundle.Enable()
s.runSelfCommand("debug", "true")
}()
case <-s.mQuit.ClickedCh:
systray.Quit()
return
case <-s.mGitHub.ClickedCh:
err := openURL("https://github.com/netbirdio/netbird")
if err != nil {
log.Errorf("%s", err)
}
case <-s.mUpdate.ClickedCh:
err := openURL(version.DownloadUrl())
if err != nil {
log.Errorf("%s", err)
}
case <-s.mNetworks.ClickedCh:
s.mNetworks.Disable()
go func() {
defer s.mNetworks.Enable()
s.runSelfCommand("networks", "true")
}()
case <-s.mNotifications.ClickedCh:
if s.mNotifications.Checked() {
s.mNotifications.Uncheck()
} else {
s.mNotifications.Check()
}
if s.eventManager != nil {
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
} }
}() }
} }
func (s *serviceClient) runSelfCommand(command, arg string) { func (s *serviceClient) runSelfCommand(command, arg string) {
proc, err := os.Executable() proc, err := os.Executable()
if err != nil { if err != nil {
log.Errorf("show %s failed with error: %v", command, err) log.Errorf("Error getting executable path: %v", err)
return return
} }
@@ -798,14 +815,48 @@ func (s *serviceClient) runSelfCommand(command, arg string) {
fmt.Sprintf("--daemon-addr=%s", s.addr), fmt.Sprintf("--daemon-addr=%s", s.addr),
) )
out, err := cmd.CombinedOutput() if out := s.attachOutput(cmd); out != nil {
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { defer func() {
log.Errorf("start %s UI: %v, %s", command, err, string(out)) if err := out.Close(); err != nil {
log.Errorf("Error closing log file %s: %v", s.logFile, err)
}
}()
}
log.Printf("Running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, s.addr)
err = cmd.Run()
if err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
log.Printf("Command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
} else {
log.Printf("Failed to start/run command '%s %s': %v", command, arg, err)
}
return return
} }
if len(out) != 0 {
log.Infof("command %s executed: %s", command, string(out)) log.Printf("Command '%s %s' completed successfully.", command, arg)
}
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
if s.logFile == "" {
// attach child's streams to parent's streams
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return nil
} }
out, err := os.OpenFile(s.logFile, os.O_WRONLY|os.O_APPEND, 0)
if err != nil {
log.Errorf("Failed to open log file %s: %v", s.logFile, err)
return nil
}
cmd.Stdout = out
cmd.Stderr = out
return out
} }
func normalizedVersion(version string) string { func normalizedVersion(version string) string {
@@ -818,9 +869,7 @@ func normalizedVersion(version string) string {
// onTrayExit is called when the tray icon is closed. // onTrayExit is called when the tray icon is closed.
func (s *serviceClient) onTrayExit() { func (s *serviceClient) onTrayExit() {
for _, item := range s.mExitNodeItems { s.cancel()
item.cancel()
}
} }
// getSrvClient connection to the service. // getSrvClient connection to the service.
@@ -829,7 +878,7 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil return s.conn, nil
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(s.ctx, timeout)
defer cancel() defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
@@ -983,13 +1032,15 @@ func (s *serviceClient) updateConfig() error {
sshAllowed := s.mAllowSSH.Checked() sshAllowed := s.mAllowSSH.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked() rosenpassEnabled := s.mEnableRosenpass.Checked()
notificationsDisabled := !s.mNotifications.Checked() notificationsDisabled := !s.mNotifications.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
IsLinuxDesktopClient: runtime.GOOS == "linux", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ServerSSHAllowed: &sshAllowed, ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled, RosenpassEnabled: &rosenpassEnabled,
DisableAutoConnect: &disableAutoStart, DisableAutoConnect: &disableAutoStart,
DisableNotifications: &notificationsDisabled, DisableNotifications: &notificationsDisabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
} }
if err := s.restartClient(&loginRequest); err != nil { if err := s.restartClient(&loginRequest); err != nil {

View File

@@ -5,6 +5,7 @@ const (
allowSSHMenuDescr = "Allow SSH connections" allowSSHMenuDescr = "Allow SSH connections"
autoConnectMenuDescr = "Connect automatically when the service starts" autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass" quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connect"
notificationsMenuDescr = "Enable notifications" notificationsMenuDescr = "Enable notifications"
advancedSettingsMenuDescr = "Advanced settings of the application" advancedSettingsMenuDescr = "Advanced settings of the application"
debugBundleMenuDescr = "Create and open debug information bundle" debugBundleMenuDescr = "Create and open debug information bundle"

View File

@@ -3,48 +3,721 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
"strconv"
"sync"
"time"
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types"
) )
func (s *serviceClient) createAndOpenDebugBundle() error { // Initial state for the debug collection
type debugInitialState struct {
wasDown bool
logLevel proto.LogLevel
isLevelTrace bool
}
// Debug collection parameters
type debugCollectionParams struct {
duration time.Duration
anonymize bool
systemInfo bool
upload bool
uploadURL string
enablePersistence bool
}
// UI components for progress tracking
type progressUI struct {
statusLabel *widget.Label
progressBar *widget.ProgressBar
uiControls []fyne.Disableable
window fyne.Window
}
func (s *serviceClient) showDebugUI() {
w := s.app.NewWindow("NetBird Debug")
w.SetOnClosed(s.cancel)
w.Resize(fyne.NewSize(600, 500))
w.SetFixedSize(true)
anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)
systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
systemInfoCheck.SetChecked(true)
uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
uploadCheck.SetChecked(true)
uploadURLLabel := widget.NewLabel("Debug upload URL:")
uploadURL := widget.NewEntry()
uploadURL.SetText(uptypes.DefaultBundleURL)
uploadURL.SetPlaceHolder("Enter upload URL")
uploadURLContainer := container.NewVBox(
uploadURLLabel,
uploadURL,
)
uploadCheck.OnChanged = func(checked bool) {
if checked {
uploadURLContainer.Show()
} else {
uploadURLContainer.Hide()
}
}
debugModeContainer := container.NewHBox()
runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
runForDurationCheck.SetChecked(true)
forLabel := widget.NewLabel("for")
durationInput := widget.NewEntry()
durationInput.SetText("1")
minutesLabel := widget.NewLabel("minute")
durationInput.Validator = func(s string) error {
return validateMinute(s, minutesLabel)
}
noteLabel := widget.NewLabel("Note: NetBird will be brought up and down during collection")
runForDurationCheck.OnChanged = func(checked bool) {
if checked {
forLabel.Show()
durationInput.Show()
minutesLabel.Show()
noteLabel.Show()
} else {
forLabel.Hide()
durationInput.Hide()
minutesLabel.Hide()
noteLabel.Hide()
}
}
debugModeContainer.Add(runForDurationCheck)
debugModeContainer.Add(forLabel)
debugModeContainer.Add(durationInput)
debugModeContainer.Add(minutesLabel)
statusLabel := widget.NewLabel("")
statusLabel.Hide()
progressBar := widget.NewProgressBar()
progressBar.Hide()
createButton := widget.NewButton("Create Debug Bundle", nil)
// UI controls that should be disabled during debug collection
uiControls := []fyne.Disableable{
anonymizeCheck,
systemInfoCheck,
uploadCheck,
uploadURL,
runForDurationCheck,
durationInput,
createButton,
}
createButton.OnTapped = s.getCreateHandler(
statusLabel,
progressBar,
uploadCheck,
uploadURL,
anonymizeCheck,
systemInfoCheck,
runForDurationCheck,
durationInput,
uiControls,
w,
)
content := container.NewVBox(
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
widget.NewLabel(""),
anonymizeCheck,
systemInfoCheck,
uploadCheck,
uploadURLContainer,
widget.NewLabel(""),
debugModeContainer,
noteLabel,
widget.NewLabel(""),
statusLabel,
progressBar,
createButton,
)
paddedContent := container.NewPadded(content)
w.SetContent(paddedContent)
w.Show()
}
func validateMinute(s string, minutesLabel *widget.Label) error {
if val, err := strconv.Atoi(s); err != nil || val < 1 {
return fmt.Errorf("must be a number ≥ 1")
}
if s == "1" {
minutesLabel.SetText("minute")
} else {
minutesLabel.SetText("minutes")
}
return nil
}
// disableUIControls disables the provided UI controls
func disableUIControls(controls []fyne.Disableable) {
for _, control := range controls {
control.Disable()
}
}
// enableUIControls enables the provided UI controls
func enableUIControls(controls []fyne.Disableable) {
for _, control := range controls {
control.Enable()
}
}
func (s *serviceClient) getCreateHandler(
statusLabel *widget.Label,
progressBar *widget.ProgressBar,
uploadCheck *widget.Check,
uploadURL *widget.Entry,
anonymizeCheck *widget.Check,
systemInfoCheck *widget.Check,
runForDurationCheck *widget.Check,
duration *widget.Entry,
uiControls []fyne.Disableable,
w fyne.Window,
) func() {
return func() {
disableUIControls(uiControls)
statusLabel.Show()
var url string
if uploadCheck.Checked {
url = uploadURL.Text
if url == "" {
statusLabel.SetText("Error: Upload URL is required when upload is enabled")
enableUIControls(uiControls)
return
}
}
params := &debugCollectionParams{
anonymize: anonymizeCheck.Checked,
systemInfo: systemInfoCheck.Checked,
upload: uploadCheck.Checked,
uploadURL: url,
enablePersistence: true,
}
runForDuration := runForDurationCheck.Checked
if runForDuration {
minutes, err := time.ParseDuration(duration.Text + "m")
if err != nil {
statusLabel.SetText(fmt.Sprintf("Error: Invalid duration: %v", err))
enableUIControls(uiControls)
return
}
params.duration = minutes
statusLabel.SetText(fmt.Sprintf("Running in debug mode for %d minutes...", int(minutes.Minutes())))
progressBar.Show()
progressBar.SetValue(0)
go s.handleRunForDuration(
statusLabel,
progressBar,
uiControls,
w,
params,
)
return
}
statusLabel.SetText("Creating debug bundle...")
go s.handleDebugCreation(
anonymizeCheck.Checked,
systemInfoCheck.Checked,
uploadCheck.Checked,
url,
statusLabel,
uiControls,
w,
)
}
}
func (s *serviceClient) handleRunForDuration(
statusLabel *widget.Label,
progressBar *widget.ProgressBar,
uiControls []fyne.Disableable,
w fyne.Window,
params *debugCollectionParams,
) {
progressUI := &progressUI{
statusLabel: statusLabel,
progressBar: progressBar,
uiControls: uiControls,
window: w,
}
conn, err := s.getSrvClient(failFastTimeout) conn, err := s.getSrvClient(failFastTimeout)
if err != nil { if err != nil {
return fmt.Errorf("get client: %v", err) handleError(progressUI, fmt.Sprintf("Failed to get client for debug: %v", err))
return
}
initialState, err := s.getInitialState(conn)
if err != nil {
handleError(progressUI, err.Error())
return
}
statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
if err != nil {
handleError(progressUI, err.Error())
return
}
if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
handleError(progressUI, err.Error())
return
}
s.restoreServiceState(conn, initialState)
progressUI.statusLabel.SetText("Bundle created successfully")
}
// Get initial state of the service
func (s *serviceClient) getInitialState(conn proto.DaemonServiceClient) (*debugInitialState, error) {
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
return nil, fmt.Errorf(" get status: %v", err)
}
logLevelResp, err := conn.GetLogLevel(s.ctx, &proto.GetLogLevelRequest{})
if err != nil {
return nil, fmt.Errorf("get log level: %v", err)
}
wasDown := statusResp.Status != string(internal.StatusConnected) &&
statusResp.Status != string(internal.StatusConnecting)
initialLogLevel := logLevelResp.GetLevel()
initialLevelTrace := initialLogLevel >= proto.LogLevel_TRACE
return &debugInitialState{
wasDown: wasDown,
logLevel: initialLogLevel,
isLevelTrace: initialLevelTrace,
}, nil
}
// Handle progress tracking during collection
func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time.Duration, progress *progressUI) {
progress.progressBar.Show()
progress.progressBar.SetValue(0)
startTime := time.Now()
endTime := startTime.Add(duration)
wg.Add(1)
go func() {
defer wg.Done()
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
remaining := time.Until(endTime)
if remaining <= 0 {
remaining = 0
}
elapsed := time.Since(startTime)
progressVal := float64(elapsed) / float64(duration)
if progressVal > 1.0 {
progressVal = 1.0
}
progress.progressBar.SetValue(progressVal)
progress.statusLabel.SetText(fmt.Sprintf("Running with trace logs... %s remaining", formatDuration(remaining)))
}
}
}()
}
func (s *serviceClient) configureServiceForDebug(
conn proto.DaemonServiceClient,
state *debugInitialState,
enablePersistence bool,
) error {
if state.wasDown {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service up: %v", err)
}
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
return fmt.Errorf("set log level to TRACE: %v", err)
}
log.Info("Log level set to TRACE for debug")
}
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("bring service down: %v", err)
}
time.Sleep(time.Second)
if enablePersistence {
if _, err := conn.SetNetworkMapPersistence(s.ctx, &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("enable network map persistence: %v", err)
}
log.Info("Network map persistence enabled for debug")
}
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service back up: %v", err)
}
time.Sleep(time.Second * 3)
return nil
}
func (s *serviceClient) collectDebugData(
conn proto.DaemonServiceClient,
state *debugInitialState,
params *debugCollectionParams,
progress *progressUI,
) (string, error) {
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
defer cancel()
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return "", err
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
}
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
wg.Wait()
progress.progressBar.Hide()
progress.statusLabel.SetText("Collecting debug data...")
preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get pre-down status: %v", err)
}
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
return statusOutput, nil
}
// Create the debug bundle with collected data
func (s *serviceClient) createDebugBundleFromCollection(
conn proto.DaemonServiceClient,
params *debugCollectionParams,
statusOutput string,
progress *progressUI,
) error {
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
request := &proto.DebugBundleRequest{
Anonymize: params.anonymize,
Status: statusOutput,
SystemInfo: params.systemInfo,
}
if params.upload {
request.UploadURL = params.uploadURL
}
resp, err := conn.DebugBundle(s.ctx, request)
if err != nil {
return fmt.Errorf("create debug bundle: %v", err)
}
// Show appropriate dialog based on upload status
localPath := resp.GetPath()
uploadFailureReason := resp.GetUploadFailureReason()
uploadedKey := resp.GetUploadedKey()
if params.upload {
if uploadFailureReason != "" {
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
} else {
showUploadSuccessDialog(progress.window, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(progress.window, localPath)
}
enableUIControls(progress.uiControls)
return nil
}
// Restore service to original state
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
if state.wasDown {
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("Failed to restore down state: %v", err)
} else {
log.Info("Service state restored to down")
}
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
log.Errorf("Failed to restore log level: %v", err)
} else {
log.Info("Log level restored to original setting")
}
}
}
// Handle errors during debug collection
func handleError(progress *progressUI, errMsg string) {
log.Errorf("%s", errMsg)
progress.statusLabel.SetText(errMsg)
progress.progressBar.Hide()
enableUIControls(progress.uiControls)
}
func (s *serviceClient) handleDebugCreation(
anonymize bool,
systemInfo bool,
upload bool,
uploadURL string,
statusLabel *widget.Label,
uiControls []fyne.Disableable,
w fyne.Window,
) {
log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...",
anonymize, systemInfo, upload)
resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL)
if err != nil {
log.Errorf("Failed to create debug bundle: %v", err)
statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))
enableUIControls(uiControls)
return
}
localPath := resp.GetPath()
uploadFailureReason := resp.GetUploadFailureReason()
uploadedKey := resp.GetUploadedKey()
if upload {
if uploadFailureReason != "" {
showUploadFailedDialog(w, localPath, uploadFailureReason)
} else {
showUploadSuccessDialog(w, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(w, localPath)
}
enableUIControls(uiControls)
statusLabel.SetText("Bundle created successfully")
}
func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploadURL string) (*proto.DebugBundleResponse, error) {
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return nil, fmt.Errorf("get client: %v", err)
} }
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
return fmt.Errorf("failed to get status: %v", err) log.Warnf("failed to get status for debug bundle: %v", err)
} }
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, true, "", nil, nil, nil) var statusOutput string
statusOutput := nbstatus.ParseToFullDetailSummary(overview) if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
resp, err := conn.DebugBundle(s.ctx, &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: true, Anonymize: anonymize,
Status: statusOutput, Status: statusOutput,
SystemInfo: true, SystemInfo: systemInfo,
}) }
if uploadURL != "" {
request.UploadURL = uploadURL
}
resp, err := conn.DebugBundle(s.ctx, request)
if err != nil { if err != nil {
return fmt.Errorf("failed to create debug bundle: %v", err) return nil, fmt.Errorf("failed to create debug bundle via daemon: %v", err)
} }
bundleDir := filepath.Dir(resp.GetPath()) return resp, nil
if err := open.Start(bundleDir); err != nil { }
return fmt.Errorf("failed to open debug bundle directory: %v", err)
} // formatDuration formats a duration in HH:MM:SS format
func formatDuration(d time.Duration) string {
s.app.SendNotification(fyne.NewNotification( d = d.Round(time.Second)
"Debug Bundle", h := d / time.Hour
fmt.Sprintf("Debug bundle created at %s. Administrator privileges are required to access it.", resp.GetPath()), d %= time.Hour
)) m := d / time.Minute
d %= time.Minute
return nil s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
// createButtonWithAction creates a button with the given label and action
func createButtonWithAction(label string, action func()) *widget.Button {
button := widget.NewButton(label, action)
return button
}
// showUploadFailedDialog displays a dialog when upload fails
func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
content := container.NewVBox(
widget.NewLabel(fmt.Sprintf("Bundle upload failed:\n%s\n\n"+
"A local copy was saved at:\n%s", failureReason, localPath)),
)
customDialog := dialog.NewCustom("Upload Failed", "Cancel", content, w)
buttonBox := container.NewHBox(
createButtonWithAction("Open file", func() {
log.Infof("Attempting to open local file: %s", localPath)
if openErr := open.Start(localPath); openErr != nil {
log.Errorf("Failed to open local file '%s': %v", localPath, openErr)
dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w)
}
}),
createButtonWithAction("Open folder", func() {
folderPath := filepath.Dir(localPath)
log.Infof("Attempting to open local folder: %s", folderPath)
if openErr := open.Start(folderPath); openErr != nil {
log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr)
dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w)
}
}),
)
content.Add(buttonBox)
customDialog.Show()
}
// showUploadSuccessDialog displays a dialog when upload succeeds
func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
log.Infof("Upload key: %s", uploadedKey)
keyEntry := widget.NewEntry()
keyEntry.SetText(uploadedKey)
keyEntry.Disable()
content := container.NewVBox(
widget.NewLabel("Bundle uploaded successfully!"),
widget.NewLabel(""),
widget.NewLabel("Upload key:"),
keyEntry,
widget.NewLabel(""),
widget.NewLabel(fmt.Sprintf("Local copy saved at:\n%s", localPath)),
)
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
copyBtn := createButtonWithAction("Copy key", func() {
w.Clipboard().SetContent(uploadedKey)
log.Info("Upload key copied to clipboard")
})
buttonBox := createButtonBox(localPath, w, copyBtn)
content.Add(buttonBox)
customDialog.Show()
}
// showBundleCreatedDialog displays a dialog when bundle is created without upload
func showBundleCreatedDialog(w fyne.Window, localPath string) {
content := container.NewVBox(
widget.NewLabel(fmt.Sprintf("Bundle created locally at:\n%s\n\n"+
"Administrator privileges may be required to access the file.", localPath)),
)
customDialog := dialog.NewCustom("Debug Bundle Created", "Cancel", content, w)
buttonBox := createButtonBox(localPath, w, nil)
content.Add(buttonBox)
customDialog.Show()
}
func createButtonBox(localPath string, w fyne.Window, elems ...fyne.Widget) *fyne.Container {
box := container.NewHBox()
for _, elem := range elems {
box.Add(elem)
}
fileBtn := createButtonWithAction("Open file", func() {
log.Infof("Attempting to open local file: %s", localPath)
if openErr := open.Start(localPath); openErr != nil {
log.Errorf("Failed to open local file '%s': %v", localPath, openErr)
dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w)
}
})
folderBtn := createButtonWithAction("Open folder", func() {
folderPath := filepath.Dir(localPath)
log.Infof("Attempting to open local folder: %s", folderPath)
if openErr := open.Start(folderPath); openErr != nil {
log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr)
dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w)
}
})
box.Add(fileBtn)
box.Add(folderBtn)
return box
} }

View File

@@ -34,7 +34,8 @@ const (
type filter string type filter string
func (s *serviceClient) showNetworksUI() { func (s *serviceClient) showNetworksUI() {
s.wRoutes = s.app.NewWindow("Networks") s.wNetworks = s.app.NewWindow("Networks")
s.wNetworks.SetOnClosed(s.cancel)
allGrid := container.New(layout.NewGridLayout(3)) allGrid := container.New(layout.NewGridLayout(3))
go s.updateNetworks(allGrid, allNetworks) go s.updateNetworks(allGrid, allNetworks)
@@ -78,8 +79,8 @@ func (s *serviceClient) showNetworksUI() {
content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer) content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer)
s.wRoutes.SetContent(content) s.wNetworks.SetContent(content)
s.wRoutes.Show() s.wNetworks.Show()
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
} }
@@ -148,7 +149,7 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) {
grid.Add(resolvedIPsSelector) grid.Add(resolvedIPsSelector)
} }
s.wRoutes.Content().Refresh() s.wNetworks.Content().Refresh()
grid.Refresh() grid.Refresh()
} }
@@ -305,7 +306,7 @@ func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.Se
func (s *serviceClient) showError(err error) { func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50) wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes) dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wNetworks)
} }
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
@@ -316,14 +317,15 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container
} }
}() }()
s.wRoutes.SetOnClosed(func() { s.wNetworks.SetOnClosed(func() {
ticker.Stop() ticker.Stop()
s.cancel()
}) })
} }
func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
s.wRoutes.Content().Refresh() s.wNetworks.Content().Refresh()
s.updateNetworks(grid, f) s.updateNetworks(grid, f)
} }
@@ -373,7 +375,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
node.Selected, node.Selected,
) )
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(s.ctx)
s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{ s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{
MenuItem: menuItem, MenuItem: menuItem,
cancel: cancel, cancel: cancel,

View File

@@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
func (s SimpleRecord) Len() uint16 { func (s SimpleRecord) Len() uint16 {
emptyString := s.RData == "" emptyString := s.RData == ""
switch s.Type { switch s.Type {
case 1: case int(dns.TypeA):
if emptyString { if emptyString {
return 0 return 0
} }
return net.IPv4len return net.IPv4len
case 5: case int(dns.TypeCNAME):
if emptyString || s.RData == "." { if emptyString || s.RData == "." {
return 1 return 1
} }
return uint16(len(s.RData) + 1) return uint16(len(s.RData) + 1)
case 28: case int(dns.TypeAAAA):
if emptyString { if emptyString {
return 0 return 0
} }

4
go.mod
View File

@@ -59,13 +59,12 @@ require (
github.com/hashicorp/go-version v1.6.0 github.com/hashicorp/go-version v1.6.0
github.com/libdns/route53 v1.5.0 github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1 github.com/libp2p/go-netroute v0.2.1
github.com/mattn/go-sqlite3 v1.14.22
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203 github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -195,6 +194,7 @@ require (
github.com/libdns/libdns v0.2.2 // indirect github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.7 // indirect github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect

4
go.sum
View File

@@ -507,8 +507,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs= github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ= github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=

View File

@@ -59,6 +59,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"} NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false} NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false} NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
# Dashboard # Dashboard
@@ -122,6 +123,7 @@ export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
export NETBIRD_AUTH_PKCE_LOGIN_FLAG
export NETBIRD_AUTH_PKCE_AUDIENCE export NETBIRD_AUTH_PKCE_AUDIENCE
export NETBIRD_DASH_AUTH_USE_AUDIENCE export NETBIRD_DASH_AUTH_USE_AUDIENCE
export NETBIRD_DASH_AUTH_AUDIENCE export NETBIRD_DASH_AUTH_AUDIENCE

View File

@@ -95,7 +95,8 @@
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES", "Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS], "RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN, "UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN "DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
} }
} }
} }

View File

@@ -28,3 +28,4 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH
NETBIRD_TURN_EXTERNAL_IP=1.2.3.4 NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
NETBIRD_RELAY_PORT=33445 NETBIRD_RELAY_PORT=33445
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0

View File

@@ -0,0 +1,19 @@
package common
// LoginFlag introduces additional login flags to the PKCE authorization request
type LoginFlag uint8
const (
// LoginFlagPrompt adds prompt=login to the authorization request
LoginFlagPrompt LoginFlag = iota
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
LoginFlagMaxAge0
)
func (l LoginFlag) IsPromptLogin() bool {
return l == LoginFlagPrompt
}
func (l LoginFlag) IsMaxAge0Login() bool {
return l == LoginFlagMaxAge0
}

Some files were not shown because too many files have changed in this diff Show More