Compare commits

..

19 Commits

Author SHA1 Message Date
Maycon Santos
96f71ff1e1 [misc] Update tag name extraction in install.sh (#4677) 2025-10-21 19:23:11 +02:00
Viktor Liu
2fe2af38d2 [client] Clean up match domain reg entries between config changes (#4676) 2025-10-21 18:14:39 +02:00
Misha Bragin
cd9a867ad0 [client] Delete TURNConfig section from script (#4639) 2025-10-17 19:48:26 +02:00
Maycon Santos
0f9bfeff7c [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 2025-10-17 19:47:11 +02:00
Viktor Liu
f5301230bf [client] Fix status showing P2P without connection (#4661) 2025-10-17 13:31:15 +02:00
Viktor Liu
429d7d6585 [client] Support BROWSER env for login (#4654) 2025-10-17 11:10:16 +02:00
Viktor Liu
3cdb10cde7 [client] Remove rule squashing (#4653) 2025-10-17 11:09:39 +02:00
Zoltan Papp
af95aabb03 Handle the case when the service has already been down and the status recorder is not available (#4652) 2025-10-16 17:15:39 +02:00
Viktor Liu
3abae0bd17 [client] Set default wg port for new profiles (#4651) 2025-10-16 16:16:51 +02:00
Viktor Liu
8252ff41db [client] Add bind activity listener to bypass udp sockets (#4646) 2025-10-16 15:58:29 +02:00
Viktor Liu
277aa2b7cc [client] Fix missing flag values in profiles (#4650) 2025-10-16 15:13:41 +02:00
John Conley
bb37dc89ce [management] feat: Basic PocketID IDP integration (#4529) 2025-10-16 10:46:29 +02:00
Viktor Liu
000e99e7f3 [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) 2025-10-13 17:50:16 +02:00
Maycon Santos
0d2e67983a [misc] Add service definition for netbird-signal (#4620) 2025-10-10 19:16:48 +02:00
Pascal Fischer
5151f19d29 [management] pass temporary flag to validator (#4599) 2025-10-10 16:15:51 +02:00
Kostya Leschenko
bedd3cabc9 [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) 2025-10-10 15:24:24 +02:00
hakansa
d35a845dbd [management] sync all other peers on peer add/remove (#4614) 2025-10-09 21:18:00 +02:00
hakansa
4e03f708a4 fix dns forwarder port update (#4613)
fix dns forwarder port update (#4613)
2025-10-09 17:39:02 +03:00
Ashley
654aa9581d [client,gui] Update url_windows.go to offer arm64 executable download (#4586) 2025-10-08 21:27:32 +02:00
73 changed files with 2252 additions and 1427 deletions

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
@@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil {
if err := openBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
srv, err := sig.NewServer(context.Background(), otel.Meter(""), nil)
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)

View File

@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"

View File

@@ -23,4 +23,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *WGTunDevice) GetICEBind() EndpointManager {
return t.iceBind
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}

View File

@@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns nil for kernel mode devices
func (t *TunKernelDevice) GetICEBind() EndpointManager {
return nil
}

View File

@@ -21,6 +21,7 @@ type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
EndpointManager
}
type TunNetstackDevice struct {
@@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}
// GetICEBind returns the bind instance
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
return t.bind
}

View File

@@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -0,0 +1,13 @@
package device
import (
"net"
"net/netip"
)
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
// Implemented by bind.ICEBind and bind.RelayBindJS.
type EndpointManager interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
}

View File

@@ -21,4 +21,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun == nil {
return nil
}
return w.tun.GetICEBind()
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind

View File

@@ -29,11 +29,6 @@ type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
}
type protoMatch struct {
ips map[string]int
policyID []byte
}
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap)
rules := networkMap.FirewallRules
enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
// if TCP protocol rules not squashed and SSH enabled
// we add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP 22 port).
// If SSH enabled, add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic.
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
totalIPs++
}
}
in := map[mgmProto.RuleProtocol]*protoMatch{}
out := map[mgmProto.RuleProtocol]*protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if:
// 1. Any of the rule has DROP action type.
// 2. Any of rule contains Port.
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
}
// special case, when we receive this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return
}
ipset := protocols[r.Protocol].ips
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
}
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
}
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
protocolOrders := []mgmProto.RuleProtocol{
mgmProto.RuleProtocol_ALL,
mgmProto.RuleProtocol_ICMP,
mgmProto.RuleProtocol_TCP,
mgmProto.RuleProtocol_UDP,
}
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
// don't squash if :
// 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network.
continue
}
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
PolicyID: match.policyID,
})
squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
}
}
}
squash(in, mgmProto.RuleDirection_IN)
squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
if len(squashedRules) == 0 {
return networkMap.FirewallRules, squashedProtocols
}
var rules []*mgmProto.FirewallRule
// filter out rules which was squashed from final list
// if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
}
}
rules = append(rules, r)
}
return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)

View File

@@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
func TestDefaultManagerSquashRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules))
r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
}
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -14,6 +14,9 @@ type WGIface interface {
}
func (g *BundleGenerator) addWgShow() error {
if g.statusRecorder == nil {
return fmt.Errorf("no status recorder available for wg show")
}
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry"
)
var (
@@ -197,6 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
}
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("cleanup old dns match policies: %s", err)
}
if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
@@ -204,9 +209,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
}
r.nrptEntryCount = count
} else {
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
r.nrptEntryCount = 0
}
@@ -273,9 +275,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return fmt.Errorf("remove existing dns policy: %w", err)
}
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
}
defer closer(regKey)

View File

@@ -0,0 +1,102 @@
package dns
import (
"fmt"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry"
)
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
}
k.Close()
return true, nil
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}

View File

@@ -31,6 +31,7 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Warnf("failed to set DNSSEC to 'no': %v", err)
}
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
}
var (
searchDomains []string
matchDomains []string

View File

@@ -40,7 +40,6 @@ type Manager struct {
fwRules []firewall.Rule
tcpRules []firewall.Rule
dnsForwarder *DNSForwarder
port uint16
}
func ListenPort() uint16 {
@@ -49,11 +48,16 @@ func ListenPort() uint16 {
return listenPort
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
func SetListenPort(port uint16) {
listenPortMu.Lock()
listenPort = port
listenPortMu.Unlock()
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{
firewall: fw,
statusRecorder: statusRecorder,
port: port,
}
}
@@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err
}
if m.port > 0 {
listenPortMu.Lock()
listenPort = m.port
listenPortMu.Unlock()
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {

View File

@@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder(
return
}
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder(
if len(fwdEntries) > 0 {
switch {
case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
@@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil

View File

@@ -1512,7 +1512,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""), nil)
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)

View File

@@ -0,0 +1,82 @@
package activity
import (
"context"
"io"
"net"
"time"
)
// lazyConn detects activity when WireGuard attempts to send packets.
// It does not deliver packets, only signals that activity occurred.
type lazyConn struct {
activityCh chan struct{}
ctx context.Context
cancel context.CancelFunc
}
// newLazyConn creates a new lazyConn for activity detection.
func newLazyConn() *lazyConn {
ctx, cancel := context.WithCancel(context.Background())
return &lazyConn{
activityCh: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
}
}
// Read blocks until the connection is closed.
func (c *lazyConn) Read(_ []byte) (n int, err error) {
<-c.ctx.Done()
return 0, io.EOF
}
// Write signals activity detection when ICEBind routes packets to this endpoint.
func (c *lazyConn) Write(b []byte) (n int, err error) {
if c.ctx.Err() != nil {
return 0, io.EOF
}
select {
case c.activityCh <- struct{}{}:
default:
}
return len(b), nil
}
// ActivityChan returns the channel that signals when activity is detected.
func (c *lazyConn) ActivityChan() <-chan struct{} {
return c.activityCh
}
// Close closes the connection.
func (c *lazyConn) Close() error {
c.cancel()
return nil
}
// LocalAddr returns the local address.
func (c *lazyConn) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}
// RemoteAddr returns the remote address.
func (c *lazyConn) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}
// SetDeadline sets the read and write deadlines.
func (c *lazyConn) SetDeadline(_ time.Time) error {
return nil
}
// SetReadDeadline sets the deadline for future Read calls.
func (c *lazyConn) SetReadDeadline(_ time.Time) error {
return nil
}
// SetWriteDeadline sets the deadline for future Write calls.
func (c *lazyConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@@ -0,0 +1,127 @@
package activity
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
type bindProvider interface {
GetBind() device.EndpointManager
}
const (
// lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers.
// The actual routing is done via fakeIP in ICEBind, not by this port.
lazyBindPort = 17473
)
// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode.
type BindListener struct {
wgIface WgInterface
peerCfg lazyconn.PeerConfig
done sync.WaitGroup
lazyConn *lazyConn
bind device.EndpointManager
fakeIP netip.Addr
}
// NewBindListener creates a listener that passes data directly through bind using LazyConn.
// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range.
func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) {
fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs)
if err != nil {
return nil, fmt.Errorf("derive fake IP: %w", err)
}
d := &BindListener{
wgIface: wgIface,
peerCfg: cfg,
bind: bind,
fakeIP: fakeIP,
}
if err := d.setupLazyConn(); err != nil {
return nil, fmt.Errorf("setup lazy connection: %v", err)
}
d.done.Add(1)
return d, nil
}
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
if len(allowedIPs) == 0 {
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
}
ourNetwork := wgIface.Address().Network
var peerIP netip.Addr
for _, allowedIP := range allowedIPs {
ip := allowedIP.Addr()
if !ip.Is4() {
continue
}
if ourNetwork.Contains(ip) {
peerIP = ip
break
}
}
if !peerIP.IsValid() {
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
}
octets := peerIP.As4()
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
return fakeIP, nil
}
func (d *BindListener) setupLazyConn() error {
d.lazyConn = newLazyConn()
d.bind.SetEndpoint(d.fakeIP, d.lazyConn)
endpoint := &net.UDPAddr{
IP: d.fakeIP.AsSlice(),
Port: lazyBindPort,
}
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil)
}
// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed.
func (d *BindListener) ReadPackets() {
select {
case <-d.lazyConn.ActivityChan():
d.peerCfg.Log.Infof("activity detected via LazyConn")
case <-d.lazyConn.ctx.Done():
d.peerCfg.Log.Infof("exit from activity listener")
}
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.lazyConn.Close()
d.bind.RemoveEndpoint(d.fakeIP)
d.done.Done()
}
// Close stops the listener and cleans up resources.
func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
if err := d.lazyConn.Close(); err != nil {
d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err)
}
d.done.Wait()
}

View File

@@ -0,0 +1,291 @@
package activity
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn
}
func newMockEndpointManager() *mockEndpointManager {
return &mockEndpointManager{
endpoints: make(map[netip.Addr]net.Conn),
}
}
func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
m.endpoints[fakeIP] = conn
}
func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) {
delete(m.endpoints, fakeIP)
}
func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn {
return m.endpoints[fakeIP]
}
// MockWGIfaceBind mocks WgInterface with bind support
type MockWGIfaceBind struct {
endpointMgr *mockEndpointManager
}
func (m *MockWGIfaceBind) RemovePeer(string) error {
return nil
}
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
func (m *MockWGIfaceBind) IsUserspaceBind() bool {
return true
}
func (m *MockWGIfaceBind) Address() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
}
}
func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr
}
func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
expectedFakeIP := netip.MustParseAddr("127.2.0.2")
conn := mockEndpointMgr.GetEndpoint(expectedFakeIP)
require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager")
_, ok := conn.(*lazyConn)
assert.True(t, ok, "Registered endpoint should be a lazyConn")
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
}
func TestBindListener_ActivityDetection(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
activityDetected := make(chan struct{})
go func() {
listener.ReadPackets()
close(activityDetected)
}()
fakeIP := listener.fakeIP
conn := mockEndpointMgr.GetEndpoint(fakeIP)
require.NotNil(t, conn, "Endpoint should be registered")
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case <-activityDetected:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity detection")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection")
}
func TestBindListener_Close(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
fakeIP := listener.fakeIP
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close")
}
func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
mgr := NewManager(mockIface)
defer mgr.Close()
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
err := mgr.MonitorPeerActivity(cfg)
require.NoError(t, err)
listener, exists := mgr.GetPeerListener(cfg.PeerConnID)
require.True(t, exists, "Peer listener should be found")
bindListener, ok := listener.(*BindListener)
require.True(t, ok, "Listener should be BindListener, got %T", listener)
fakeIP := bindListener.fakeIP
conn := mockEndpointMgr.GetEndpoint(fakeIP)
require.NotNil(t, conn, "Endpoint should be registered")
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity")
}
func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer1 := &MocPeer{PeerID: "testPeer1"}
peer2 := &MocPeer{PeerID: "testPeer2"}
mgr := NewManager(mockIface)
defer mgr.Close()
cfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
cfg2 := lazyconn.PeerConfig{
PublicKey: peer2.PeerID,
PeerConnID: peer2.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")},
Log: log.WithField("peer", "testPeer2"),
}
err := mgr.MonitorPeerActivity(cfg1)
require.NoError(t, err)
err = mgr.MonitorPeerActivity(cfg2)
require.NoError(t, err)
listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID)
require.True(t, exists, "Peer1 listener should be found")
bindListener1 := listener1.(*BindListener)
listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID)
require.True(t, exists, "Peer2 listener should be found")
bindListener2 := listener2.(*BindListener)
conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP)
require.NotNil(t, conn1, "Peer1 endpoint should be registered")
_, err = conn1.Write([]byte{0x01})
require.NoError(t, err)
conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP)
require.NotNil(t, conn2, "Peer2 endpoint should be registered")
_, err = conn2.Write([]byte{0x02})
require.NoError(t, err)
receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ {
select {
case peerConnID := <-mgr.OnActivityChan:
receivedPeers[peerConnID] = true
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications")
}
}
assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received")
assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received")
}

View File

@@ -1,41 +0,0 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -11,26 +11,27 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
// UDPListener uses UDP sockets for activity detection in kernel mode.
type UDPListener struct {
wgIface WgInterface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener
isClosed atomic.Bool
}
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
// NewUDPListener creates a listener that detects activity via UDP socket reads.
func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) {
d := &UDPListener{
wgIface: wgIface,
peerCfg: cfg,
}
conn, err := d.newConn()
if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
return nil, fmt.Errorf("create UDP connection: %v", err)
}
d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
@@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error
if err := d.createEndpoint(); err != nil {
return nil, err
}
d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String())
return d, nil
}
func (d *Listener) ReadPackets() {
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
func (d *UDPListener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
@@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() {
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil {
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
// Ignore close error as it may return "use of closed network connection" if already closed.
_ = d.conn.Close()
d.done.Unlock()
}
func (d *Listener) Close() {
// Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
@@ -82,16 +87,12 @@ func (d *Listener) Close() {
d.done.Lock()
}
func (d *Listener) removeEndpoint() error {
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
func (d *UDPListener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
}
func (d *Listener) newConn() (*net.UDPConn, error) {
func (d *UDPListener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{
Port: 0,
IP: listenIP,

View File

@@ -0,0 +1,110 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestUDPListener_Creation(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
require.NotNil(t, listener.conn)
require.NotNil(t, listener.endpoint)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
}
func TestUDPListener_ActivityDetection(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
activityDetected := make(chan struct{})
go func() {
listener.ReadPackets()
close(activityDetected)
}()
conn, err := net.Dial("udp", listener.conn.LocalAddr().String())
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case <-activityDetected:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity detection")
}
}
func TestUDPListener_Close(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed")
}

View File

@@ -1,21 +1,32 @@
package activity
import (
"errors"
"net"
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
// listener defines the contract for activity detection listeners.
type listener interface {
ReadPackets()
Close()
}
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
}
type Manager struct {
@@ -23,7 +34,7 @@ type Manager struct {
wgIface WgInterface
peers map[peerid.ConnID]*Listener
peers map[peerid.ConnID]listener
done chan struct{}
mu sync.Mutex
@@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener),
peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}),
}
return m
@@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
return nil
}
listener, err := NewListener(m.wgIface, peerCfg)
listener, err := m.createListener(peerCfg)
if err != nil {
return err
}
m.peers[peerCfg.PeerConnID] = listener
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil
}
func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) {
if !m.wgIface.IsUserspaceBind() {
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is only used on Windows and JS platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// BindListener bypasses this by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider)
if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
}
return NewBindListener(m.wgIface, provider.GetBind(), peerCfg)
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -82,8 +115,8 @@ func (m *Manager) Close() {
}
}
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
listener.ReadPackets()
func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
l.ReadPackets()
m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok {

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
@@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error {
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
// Add this method to the Manager struct
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
func (m MocWGIface) IsUserspaceBind() bool {
return false
}
func (m MocWGIface) Address() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
}
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock()
defer m.mu.Unlock()
listener, exists := m.peers[peerConnID]
return listener, exists
l, exists := m.peers[peerConnID]
return l, exists
}
func TestManager_MonitorPeerActivity(t *testing.T) {
@@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
t.Fatalf("peer listener not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
// Get the UDP listener's address for triggering
udpListener, ok := listener.(*UDPListener)
if !ok {
t.Fatalf("expected UDPListener")
}
if err := trigger(udpListener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
@@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err)
}
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID)
udpListener, _ := listener.(*UDPListener)
addr := udpListener.conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
@@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer1 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
udpListener1, _ := listener.(*UDPListener)
if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
@@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer2 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
udpListener2, _ := listener.(*UDPListener)
if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}

View File

@@ -7,6 +7,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/monotime"
)
@@ -14,5 +15,6 @@ type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time
}

View File

@@ -107,14 +107,10 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard Guard
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
wg sync.WaitGroup
// used for replace the new guard with the old one in a thread-safe way
guardCtxCancel context.CancelFunc
wgGuard sync.WaitGroup
// debug purpose
dumpState *stateDump
@@ -200,34 +196,14 @@ func (conn *Conn) Open(engineCtx context.Context) error {
}
conn.wg.Add(1)
conn.wgGuard.Add(1)
guardCtx, cancel := context.WithCancel(conn.ctx)
conn.guardCtxCancel = cancel
go func() {
defer conn.wg.Done()
defer conn.wgGuard.Done()
defer cancel()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx)
conn.guard.Start(guardCtx, conn.onGuardEvent)
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
// both peer send offer
if err := conn.handshaker.SendOffer(); err != nil {
switch err {
case ErrPeerNotAvailable:
conn.Log.Warnf("failed to deliver offer to peer. Peer is not available")
case ErrSignalNotSupportDeliveryCheck:
conn.Log.Infof("signal delivery check is not supported, switch guard to retry mode")
conn.switchGuard()
default:
conn.Log.Errorf("failed to deliver offer to peer: %v", err)
conn.guard.FailedToSendOffer()
}
}
conn.opened = true
return nil
}
@@ -580,17 +556,7 @@ func (conn *Conn) onRelayDisconnected() {
func (conn *Conn) onGuardEvent() {
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil {
switch err {
case ErrPeerNotAvailable:
conn.Log.Warnf("failed to deliver offer to peer. Peer is not available")
case ErrSignalNotSupportDeliveryCheck:
conn.Log.Infof("signal delivery check is not supported, switch guard to retry mode")
// must run on a separate goroutine to prevent deadlock while close the old guard
go conn.switchGuard()
default:
conn.Log.Errorf("failed to deliver offer to peer: %v", err)
conn.guard.FailedToSendOffer()
}
conn.Log.Errorf("failed to send offer: %v", err)
}
}
@@ -812,22 +778,6 @@ func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
return &key, nil
}
func (conn *Conn) switchGuard() {
if conn.guardCtxCancel == nil {
return
}
conn.guardCtxCancel()
conn.guardCtxCancel = nil
conn.wgGuard.Wait()
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.guard = guard.NewRetryGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -1,10 +0,0 @@
package peer
import "context"
type Guard interface {
Start(ctx context.Context, eventCallback func())
SetRelayedConnDisconnected()
SetICEConnDisconnected()
FailedToSendOffer()
}

View File

@@ -4,26 +4,20 @@ import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
const (
offerResendPeriod = 2 * time.Second
)
type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues.
// Only the offer error will start the timer to resend offer periodically.
//
// Watch these events:
// - Relay client reconnected to home server
// - Signal server connection state changed
// - ICE connection disconnected
// - Relayed connection disconnected
// - ICE candidate changes
// - Failed to send offer to remote peer
type Guard struct {
log *log.Entry
isConnectedOnAllWay isConnectedFunc
@@ -31,7 +25,6 @@ type Guard struct {
srWatcher *SRWatcher
relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{}
offerError chan struct{}
}
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
@@ -42,7 +35,6 @@ func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Durati
srWatcher: srWatcher,
relayedConnDisconnected: make(chan struct{}, 1),
iCEConnDisconnected: make(chan struct{}, 1),
offerError: make(chan struct{}, 1),
}
}
@@ -65,54 +57,81 @@ func (g *Guard) SetICEConnDisconnected() {
}
}
func (g *Guard) FailedToSendOffer() {
select {
case g.offerError <- struct{}{}:
default:
}
}
// reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
offerResendTimer := time.NewTimer(0)
offerResendTimer.Stop()
defer offerResendTimer.Stop()
ticker := g.initialTicker(ctx)
defer ticker.Stop()
tickerChannel := ticker.C
for {
select {
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
tickerChannel = make(<-chan time.Time)
continue
}
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, retry connection")
offerResendTimer.Stop()
if !g.isConnectedOnAllWay() {
callback()
}
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, retry connection")
offerResendTimer.Stop()
if !g.isConnectedOnAllWay() {
callback()
}
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-srReconnectedChan:
g.log.Debugf("has network changes, retry connection")
offerResendTimer.Stop()
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.offerError:
g.log.Debugf("failed to send offer, reset reconnection ticker")
offerResendTimer.Reset(offerResendPeriod)
continue
case <-offerResendTimer.C:
if !g.isConnectedOnAllWay() {
callback()
}
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
// initialTicker give chance to the peer to establish the initial connection.
func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 3 * time.Second,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
return ticker
}

View File

@@ -1,139 +0,0 @@
package guard
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
// RetryGuard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues.
// Watch these events:
// - Relay client reconnected to home server
// - Signal server connection state changed
// - ICE connection disconnected
// - Relayed connection disconnected
// - ICE candidate changes
type RetryGuard struct {
log *log.Entry
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{}
}
func (g *RetryGuard) FailedToSendOffer() {
log.Errorf("FailedToSendOffer is not implemented in GuardRetry")
}
func NewRetryGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *RetryGuard {
return &RetryGuard{
log: log,
isConnectedOnAllWay: isConnectedFn,
timeout: timeout,
srWatcher: srWatcher,
relayedConnDisconnected: make(chan struct{}, 1),
iCEConnDisconnected: make(chan struct{}, 1),
}
}
func (g *RetryGuard) Start(ctx context.Context, eventCallback func()) {
g.log.Infof("starting guard for reconnection with MaxInterval: %s", g.timeout)
g.reconnectLoopWithRetry(ctx, eventCallback)
}
func (g *RetryGuard) SetRelayedConnDisconnected() {
select {
case g.relayedConnDisconnected <- struct{}{}:
default:
}
}
func (g *RetryGuard) SetICEConnDisconnected() {
select {
case g.iCEConnDisconnected <- struct{}{}:
default:
}
}
// reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *RetryGuard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.initialTicker(ctx)
defer ticker.Stop()
tickerChannel := ticker.C
for {
select {
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
tickerChannel = make(<-chan time.Time)
continue
}
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
// initialTicker give chance to the peer to establish the initial connection.
func (g *RetryGuard) initialTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 3 * time.Second,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}
func (g *RetryGuard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
return ticker
}

View File

@@ -1,9 +1,6 @@
package peer
import (
"errors"
"sync/atomic"
"github.com/pion/ice/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -12,16 +9,9 @@ import (
sProto "github.com/netbirdio/netbird/shared/signal/proto"
)
var (
ErrPeerNotAvailable = signal.ErrPeerNotAvailable
ErrSignalNotSupportDeliveryCheck = errors.New("the signal client does not support SendWithDeliveryCheck")
)
type Signaler struct {
signal signal.Client
wgPrivateKey wgtypes.Key
deliveryCheckNotSupported atomic.Bool
}
func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler {
@@ -77,21 +67,10 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
return err
}
if s.deliveryCheckNotSupported.Load() {
return s.signal.Send(msg)
if err = s.signal.Send(msg); err != nil {
return err
}
if err = s.signal.SendWithDeliveryCheck(msg); err != nil {
switch {
case errors.Is(err, signal.ErrPeerNotAvailable):
return ErrPeerNotAvailable
case errors.Is(err, signal.ErrUnimplementedMethod):
s.handleUnimplementedMethod(msg)
return ErrSignalNotSupportDeliveryCheck
default:
return err
}
}
return nil
}
@@ -104,15 +83,3 @@ func (s *Signaler) SignalIdle(remoteKey string) error {
},
})
}
func (s *Signaler) handleUnimplementedMethod(msg *sProto.Message) {
// print out the warning only once
if !s.deliveryCheckNotSupported.Load() {
log.Warnf("signal client does not support delivery check, falling back to Send method and resend")
}
s.deliveryCheckNotSupported.Store(true)
if err := s.signal.Send(msg); err != nil {
log.Warnf("failed to send signal msg to remote peer: %v", err)
}
}

View File

@@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
WgPort: iface.DefaultWgPort,
}
if _, err := config.apply(input); err != nil {

View File

@@ -5,11 +5,14 @@ import (
"errors"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/util"
)
@@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) {
}
}
func TestNewProfileDefaults(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
})
require.NoError(t, err, "should create new config")
assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default")
assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default")
assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated")
assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated")
assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default")
assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820")
assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default")
assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default")
assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set")
assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set")
assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults")
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS")
assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS")
}
}
func TestWireguardPortZeroExplicit(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
// Create a new profile with explicit port 0 (random port)
explicitZero := 0
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
WireguardPort: &explicitZero,
})
require.NoError(t, err, "should create config with explicit port 0")
assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user")
// Verify it persists
readConfig, err := GetConfig(configPath)
require.NoError(t, err)
assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file")
}
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
tests := []struct {
name string
wireguardPort *int
expectedPort int
description string
}{
{
name: "no port specified uses default",
wireguardPort: nil,
expectedPort: iface.DefaultWgPort,
description: "When user doesn't specify port, default to 51820",
},
{
name: "explicit zero for random port",
wireguardPort: func() *int { v := 0; return &v }(),
expectedPort: 0,
description: "When user explicitly sets 0, use 0 for random port",
},
{
name: "explicit custom port",
wireguardPort: func() *int { v := 52000; return &v }(),
expectedPort: 52000,
description: "When user sets custom port, use that port",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
WireguardPort: tt.wireguardPort,
})
require.NoError(t, err, tt.description)
assert.Equal(t, tt.expectedPort, config.WgPort, tt.description)
})
}
}
func TestUpdateOldManagementURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -0,0 +1,59 @@
package winregistry
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows/registry"
)
var (
advapi = syscall.NewLazyDLL("advapi32.dll")
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
)
const (
// Registry key options
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
// Registry disposition values
regCreatedNewKey = 0x1
regOpenedExistingKey = 0x2
)
// CreateVolatileKey creates a volatile registry key named path under open key root.
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
// The access parameter specifies the access rights for the key to be created.
//
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
// This provides automatic cleanup without requiring manual registry maintenance.
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
pathPtr, err := syscall.UTF16PtrFromString(path)
if err != nil {
return 0, false, err
}
var (
handle syscall.Handle
disposition uint32
)
ret, _, _ := regCreateKeyExW.Call(
uintptr(root),
uintptr(unsafe.Pointer(pathPtr)),
0, // reserved
0, // class
uintptr(regOptionVolatile), // options - volatile key
uintptr(access), // desired access
0, // security attributes
uintptr(unsafe.Pointer(&handle)),
uintptr(unsafe.Pointer(&disposition)),
)
if ret != 0 {
return 0, false, syscall.Errno(ret)
}
return registry.Key(handle), disposition == regOpenedExistingKey, nil
}

View File

@@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.CustomDNSAddress = []byte{}
}
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
if msg.DnsRouteInterval != nil {
interval := msg.DnsRouteInterval.AsDuration()
config.DNSRouteInterval = &interval
}
config.RosenpassEnabled = msg.RosenpassEnabled
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect

View File

@@ -345,7 +345,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""), nil)
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)

View File

@@ -0,0 +1,298 @@
package server
import (
"context"
"os/user"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
)
// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
// This test uses reflection to detect when new fields are added but not handled in SetConfig.
func TestSetConfig_AllFieldsSaved(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origDefaultConfigPath := profilemanager.DefaultConfigPath
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.ConfigDirOverride = tempDir
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
profilemanager.DefaultConfigPath = origDefaultConfigPath
profilemanager.ConfigDirOverride = ""
})
currUser, err := user.Current()
require.NoError(t, err)
profName := "test-profile"
ic := profilemanager.ConfigInput{
ConfigPath: filepath.Join(tempDir, profName+".json"),
ManagementURL: "https://api.netbird.io:443",
}
_, err = profilemanager.UpdateOrCreateConfig(ic)
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
Username: currUser.Username,
})
require.NoError(t, err)
ctx := context.Background()
s := New(ctx, "console", "", false, false)
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
disableAutoConnect := true
networkMonitor := true
disableClientRoutes := true
disableServerRoutes := true
disableDNS := true
disableFirewall := true
blockLANAccess := true
disableNotifications := true
lazyConnectionEnabled := true
blockInbound := true
mtu := int64(1280)
req := &proto.SetConfigRequest{
ProfileName: profName,
Username: currUser.Username,
ManagementUrl: "https://new-api.netbird.io:443",
AdminURL: "https://new-admin.netbird.io",
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
DisableAutoConnect: &disableAutoConnect,
NetworkMonitor: &networkMonitor,
DisableClientRoutes: &disableClientRoutes,
DisableServerRoutes: &disableServerRoutes,
DisableDns: &disableDNS,
DisableFirewall: &disableFirewall,
BlockLanAccess: &blockLANAccess,
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
DnsLabels: []string{"label1", "label2"},
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
}
_, err = s.SetConfig(ctx, req)
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
Name: profName,
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()
require.NoError(t, err)
cfg, err := profilemanager.GetConfig(cfgPath)
require.NoError(t, err)
require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
require.NotNil(t, cfg.NetworkMonitor)
require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
require.Equal(t, disableDNS, cfg.DisableDNS)
require.Equal(t, disableFirewall, cfg.DisableFirewall)
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
require.NotNil(t, cfg.DisableNotifications)
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
require.Equal(t, blockInbound, cfg.BlockInbound)
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
// IFaceBlackList contains defaults + extras
require.Contains(t, cfg.IFaceBlackList, "eth1")
require.Contains(t, cfg.IFaceBlackList, "eth2")
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
require.Equal(t, uint16(mtu), cfg.MTU)
verifyAllFieldsCovered(t, req)
}
// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
// If a new field is added to SetConfigRequest, this function will fail the test,
// forcing the developer to update both the SetConfig handler and this test.
func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
t.Helper()
metadataFields := map[string]bool{
"state": true, // protobuf internal
"sizeCache": true, // protobuf internal
"unknownFields": true, // protobuf internal
"Username": true, // metadata
"ProfileName": true, // metadata
"CleanNATExternalIPs": true, // control flag for clearing
"CleanDNSLabels": true, // control flag for clearing
}
expectedFields := map[string]bool{
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
}
val := reflect.ValueOf(req).Elem()
typ := val.Type()
var unexpectedFields []string
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldName := field.Name
if metadataFields[fieldName] {
continue
}
if !expectedFields[fieldName] {
unexpectedFields = append(unexpectedFields, fieldName)
}
}
if len(unexpectedFields) > 0 {
t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
}
}
// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
func TestCLIFlags_MappedToSetConfig(t *testing.T) {
// Map of CLI flag names to their corresponding SetConfigRequest field names.
// This map must be updated when adding new config-related CLI flags.
flagToField := map[string]string{
"management-url": "ManagementUrl",
"admin-url": "AdminURL",
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",
"disable-auto-connect": "DisableAutoConnect",
"network-monitor": "NetworkMonitor",
"disable-client-routes": "DisableClientRoutes",
"disable-server-routes": "DisableServerRoutes",
"disable-dns": "DisableDns",
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
"extra-dns-labels": "DnsLabels",
"dns-router-interval": "DnsRouteInterval",
"mtu": "Mtu",
}
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
fieldsWithoutCLIFlags := map[string]bool{
"DisableNotifications": true, // Only settable via UI
}
// Get all SetConfigRequest fields to verify our map is complete.
req := &proto.SetConfigRequest{}
val := reflect.ValueOf(req).Elem()
typ := val.Type()
var unmappedFields []string
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldName := field.Name
// Skip protobuf internal fields and metadata fields.
if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" {
continue
}
if fieldName == "Username" || fieldName == "ProfileName" {
continue
}
if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" {
continue
}
// Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag.
mappedToCLI := false
for _, mappedField := range flagToField {
if mappedField == fieldName {
mappedToCLI = true
break
}
}
hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName]
if !mappedToCLI && !hasNoCLIFlag {
unmappedFields = append(unmappedFields, fieldName)
}
}
if len(unmappedFields) > 0 {
t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
"Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
"add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
}
t.Log("All SetConfigRequest fields are properly documented")
}

View File

@@ -205,15 +205,18 @@ func mapPeers(
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := "P2P"
connType := "-"
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if pbPeerState.Relayed {
connType = "Relayed"
if isPeerConnected {
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
}
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {

View File

@@ -31,7 +31,6 @@ import (
"fyne.io/systray"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
}
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
err := open.Run(loginResp.VerificationURIComplete)
err := openURL(loginResp.VerificationURIComplete)
if err != nil {
log.Errorf("opening the verification uri in the browser failed: %v", err)
return err
@@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
}
func openURL(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
var err error
switch runtime.GOOS {
case "windows":

View File

@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
return &tls.Config{
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
return nil
},
}
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
}

View File

@@ -6,11 +6,13 @@ import (
"context"
"crypto/tls"
"encoding/asn1"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
@@ -19,18 +21,34 @@ const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
Version int64 `asn1:"tag:0,explicit"`
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathErr struct {
ErrorCode int16 `asn1:"tag:0,explicit"`
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
WSALastError int16 `asn1:"tag:2,explicit,optional"`
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
}
type RDCleanPathProxy struct {
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func errorToWSACode(err error) int16 {
if err == nil {
return WSAEGenericError
}
var netErr *net.OpError
if errors.As(err, &netErr) && netErr.Timeout() {
return WSAETimedOut
}
if errors.Is(err, context.DeadlineExceeded) {
return WSAETimedOut
}
if errors.Is(err, context.Canceled) {
return WSAEConnAborted
}
if errors.Is(err, io.EOF) {
return WSAEConnReset
}
return WSAEGenericError
}
func newWSAError(err error) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
WSALastError: errorToWSACode(err),
},
}
}
func newHTTPError(statusCode int16) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
HTTPStatusCode: statusCode,
},
}
}

View File

@@ -3,6 +3,7 @@
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
@@ -11,11 +12,17 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, "Unsupported version")
p.sendRDCleanPathError(conn, newHTTPError(400))
return
}
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
destination = pdu.Destination
}
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, "Connection failed")
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
p.setupTLSConnection(conn, pdu)
}
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
const minResponseLength = 19
if len(x224Response) < minResponseLength {
return false, 0, false
}
// Per X.224 specification:
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
// x224Response[5] == 0xD0: X.224 Data TPDU code
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
return false, 0, false
}
if x224Response[11] == 0x02 {
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
return hasNLA, flags, true
}
return false, 0, false
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
tlsConfig := p.getTLSConfigWithValidation(conn)
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, "TLS handshake failed")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)

4
go.mod
View File

@@ -62,8 +62,8 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250930114722-bab681dd3a96
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible

8
go.sum
View File

@@ -503,12 +503,12 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250930114722-bab681dd3a96 h1:rPu2HVjtRZs/GloIz6h4tcNr/9Fq55SBBAlz6uOExt8=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250930114722-bab681dd3a96/go.mod h1:ap83I3Rs3SLND/58j9X+a3ixA9d9ztainwW01ExEw0w=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=

View File

@@ -49,6 +49,7 @@ services:
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.routers.netbird-signal.service=netbird-signal
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -682,17 +682,6 @@ renderManagementJson() {
"URI": "stun:$NETBIRD_DOMAIN:3478"
}
],
"TURNConfig": {
"Turns": [
{
"Proto": "udp",
"URI": "turn:$NETBIRD_DOMAIN:3478",
"Username": "$TURN_USER",
"Password": "$TURN_PASSWORD"
}
],
"TimeBasedCredentials": false
},
"Relay": {
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
"CredentialsTTL": "24h",

View File

@@ -26,9 +26,11 @@ type mockHTTPClient struct {
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
if req.Body != nil {
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
}
}
return &http.Response{
StatusCode: c.code,

View File

@@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
APIToken: config.ExtraConfig["ApiToken"],
}
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
case "pocketid":
pocketidConfig := PocketIdClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
}
return NewPocketIdManager(pocketidConfig, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}

View File

@@ -0,0 +1,384 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
)
type PocketIdManager struct {
managementEndpoint string
apiToken string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
type pocketIdCustomClaimDto struct {
Key string `json:"key"`
Value string `json:"value"`
}
type pocketIdUserDto struct {
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
Disabled bool `json:"disabled"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
FirstName string `json:"firstName"`
ID string `json:"id"`
IsAdmin bool `json:"isAdmin"`
LastName string `json:"lastName"`
LdapID string `json:"ldapId"`
Locale string `json:"locale"`
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
Username string `json:"username"`
}
type pocketIdUserCreateDto struct {
Disabled bool `json:"disabled,omitempty"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
FirstName string `json:"firstName"`
IsAdmin bool `json:"isAdmin,omitempty"`
LastName string `json:"lastName,omitempty"`
Locale string `json:"locale,omitempty"`
Username string `json:"username"`
}
type pocketIdPaginatedUserDto struct {
Data []pocketIdUserDto `json:"data"`
Pagination pocketIdPaginationDto `json:"pagination"`
}
type pocketIdPaginationDto struct {
CurrentPage int `json:"currentPage"`
ItemsPerPage int `json:"itemsPerPage"`
TotalItems int `json:"totalItems"`
TotalPages int `json:"totalPages"`
}
func (p *pocketIdUserDto) userData() *UserData {
return &UserData{
Email: p.Email,
Name: p.DisplayName,
ID: p.ID,
AppMetadata: AppMetadata{},
}
}
type pocketIdUserGroupDto struct {
CreatedAt string `json:"createdAt"`
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
FriendlyName string `json:"friendlyName"`
ID string `json:"id"`
LdapID string `json:"ldapId"`
Name string `json:"name"`
}
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ManagementEndpoint == "" {
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
}
if config.APIToken == "" {
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
}
credentials := &PocketIdCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &PocketIdManager{
managementEndpoint: config.ManagementEndpoint,
apiToken: config.APIToken,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
if !slices.Contains(MethodsWithBody, method) && body != "" {
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
}
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
if query != nil {
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
}
var req *http.Request
var err error
if body != "" {
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
} else {
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
}
if err != nil {
return nil, err
}
req.Header.Add("X-API-KEY", p.apiToken)
if body != "" {
req.Header.Add("content-type", "application/json")
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
}
resp, err := p.httpClient.Do(req)
if err != nil {
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// getAllUsersPaginated fetches all users from PocketID API using pagination
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
var allUsers []pocketIdUserDto
currentPage := 1
for {
params := url.Values{}
// Copy existing search parameters
for key, values := range searchParams {
params[key] = values
}
params.Set("pagination[limit]", "100")
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
body, err := p.request(ctx, http.MethodGet, "users", &params, "")
if err != nil {
return nil, err
}
var profiles pocketIdPaginatedUserDto
err = p.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
allUsers = append(allUsers, profiles.Data...)
// Check if we've reached the last page
if currentPage >= profiles.Pagination.TotalPages {
break
}
currentPage++
}
return allUsers, nil
}
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var user pocketIdUserDto
err = p.helper.Unmarshal(body, &user)
if err != nil {
return nil, err
}
userData := user.userData()
userData.AppMetadata = appMetadata
return userData, nil
}
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
// Get all users using pagination
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, profile := range allUsers {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountId
users = append(users, userData)
}
return users, nil
}
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
// Get all users using pagination
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range allUsers {
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
}
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
firstLast := strings.Split(name, " ")
createUser := pocketIdUserCreateDto{
Disabled: false,
DisplayName: name,
Email: email,
FirstName: firstLast[0],
LastName: firstLast[1],
Username: firstLast[0] + "." + firstLast[1],
}
payload, err := p.helper.Marshal(createUser)
if err != nil {
return nil, err
}
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
if err != nil {
return nil, err
}
var newUser pocketIdUserDto
err = p.helper.Unmarshal(body, &newUser)
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountCreateUser()
}
var pending bool = true
ret := &UserData{
Email: email,
Name: name,
ID: newUser.ID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pending,
WTInvitedBy: invitedByEmail,
},
}
return ret, nil
}
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
params := url.Values{
// This value a
"search": []string{email},
}
body, err := p.request(ctx, http.MethodGet, "users", &params, "")
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profiles struct{ data []pocketIdUserDto }
err = p.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.data {
users = append(users, profile.userData())
}
return users, nil
}
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
if err != nil {
return err
}
return nil
}
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
if err != nil {
return err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
var _ Manager = (*PocketIdManager)(nil)
type PocketIdClientConfig struct {
APIToken string
ManagementEndpoint string
}
type PocketIdCredentials struct {
clientConfig PocketIdClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}

View File

@@ -0,0 +1,138 @@
package idp
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewPocketIdManager(t *testing.T) {
type test struct {
name string
inputConfig PocketIdClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := PocketIdClientConfig{
APIToken: "api_token",
ManagementEndpoint: "http://localhost",
}
tests := []test{
{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
},
{
name: "Missing ManagementEndpoint",
inputConfig: PocketIdClientConfig{
APIToken: defaultTestConfig.APIToken,
ManagementEndpoint: "",
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
{
name: "Missing APIToken",
inputConfig: PocketIdClientConfig{
APIToken: "",
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
})
}
}
func TestPocketID_GetUserDataByID(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
md := AppMetadata{WTAccountID: "acc1"}
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
require.NoError(t, err)
assert.Equal(t, "u1", got.ID)
assert.Equal(t, "user1@example.com", got.Email)
assert.Equal(t, "User One", got.Name)
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
}
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
// Single page response with two users
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
users, err := mgr.GetAccount(context.Background(), "accX")
require.NoError(t, err)
require.Len(t, users, 2)
assert.Equal(t, "u1", users[0].ID)
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "u2", users[1].ID)
}
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
accounts, err := mgr.GetAllAccounts(context.Background())
require.NoError(t, err)
require.Len(t, accounts[UnsetAccountID], 2)
}
func TestPocketID_CreateUser(t *testing.T) {
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
require.NoError(t, err)
assert.Equal(t, "newid", ud.ID)
assert.Equal(t, "new@example.com", ud.Email)
assert.Equal(t, "New User", ud.Name)
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
assert.True(t, *ud.AppMetadata.WTPendingInvite)
}
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
}
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
// Same mock for both calls; returns OK with empty JSON
client := &mockHTTPClient{code: 200, resBody: `{}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
err = mgr.InviteUserByID(context.Background(), "u1")
require.NoError(t, err)
err = mgr.DeleteUser(context.Background(), "u1")
require.NoError(t, err)
}

View File

@@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer
}

View File

@@ -3,16 +3,16 @@ package integrated_validator
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error

View File

@@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
var peer *nbpeer.Peer
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
if err != nil {
return err
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if updateAccountPeers && userID != activity.SystemInitiator {
if userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -584,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
@@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if updateAccountPeers {
am.BufferUpdateAccountPeers(ctx, accountID)
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
@@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
if err != nil {
return false, err
}
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
}
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {

View File

@@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg) //
peerShouldReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then
NETBIRD_RELEASE=latest
fi
TAG_NAME=""
get_release() {
local RELEASE=$1
if [ "$RELEASE" = "latest" ]; then
@@ -38,17 +40,19 @@ get_release() {
local TAG="tags/${RELEASE}"
local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
fi
OUTPUT=""
if [ -n "$GITHUB_TOKEN" ]; then
curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}")
else
curl -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
OUTPUT=$(curl -s "${URL}")
fi
TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1)
echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+'
}
download_release_binary() {
VERSION=$(get_release "$NETBIRD_RELEASE")
echo "Using the following tag name for binary installation: ${TAG_NAME}"
BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download"
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz"

View File

@@ -35,7 +35,6 @@ type Client interface {
WaitStreamConnected()
SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error
SendWithDeliveryCheck(msg *proto.Message) error
SetOnReconnectedListener(func())
}

View File

@@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) {
panic(err)
}
s := grpc.NewServer()
srv, err := server.NewServer(context.Background(), otel.Meter(""), nil)
srv, err := server.NewServer(context.Background(), otel.Meter(""))
if err != nil {
panic(err)
}

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"errors"
"fmt"
"io"
"sync"
@@ -24,11 +23,6 @@ import (
"github.com/netbirdio/netbird/util/wsproxy"
)
var (
ErrPeerNotAvailable = errors.New("peer not available")
ErrUnimplementedMethod = errors.New("the signal client does not support SendWithDeliveryCheck")
)
// ConnStateNotifier is a wrapper interface of the status recorder
type ConnStateNotifier interface {
MarkSignalDisconnected(error)
@@ -403,36 +397,6 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
return err
}
func (c *GrpcClient) SendWithDeliveryCheck(msg *proto.Message) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
encryptedMessage, err := c.encryptMessage(msg)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(c.ctx, client.ConnectTimeout)
defer cancel()
_, err = c.realClient.SendWithDeliveryCheck(ctx, encryptedMessage)
if err != nil {
if st, ok := status.FromError(err); ok {
switch st.Code() {
case codes.NotFound:
return ErrPeerNotAvailable
case codes.Unimplemented:
return ErrUnimplementedMethod
default:
return fmt.Errorf("grpc error %s: %w", st.Code(), err)
}
}
return err // Not a gRPC status error
}
return err
}
// receive receives messages from other peers coming through the Signal Exchange
// and distributes them to worker threads for processing
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) error {

View File

@@ -16,7 +16,6 @@ type MockClient struct {
SendToStreamFunc func(msg *proto.EncryptedMessage) error
SendFunc func(msg *proto.Message) error
SetOnReconnectedListenerFunc func(f func())
SendWithDeliveryCheckFn func(msg *proto.Message) error
}
// SetOnReconnectedListener sets the function to be called when the client reconnects.
@@ -83,10 +82,3 @@ func (sm *MockClient) Send(msg *proto.Message) error {
}
return sm.SendFunc(msg)
}
func (sm *MockClient) SendWithDeliveryCheck(msg *proto.Message) error {
if sm.SendWithDeliveryCheckFn == nil {
return nil
}
return sm.SendWithDeliveryCheck(msg)
}

View File

@@ -10,7 +10,6 @@ import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
emptypb "google.golang.org/protobuf/types/known/emptypb"
reflect "reflect"
sync "sync"
)
@@ -440,79 +439,72 @@ var file_signalexchange_proto_rawDesc = []byte{
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74,
0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x56, 0x0a, 0x10, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79,
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c, 0x0a, 0x09, 0x72,
0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09,
0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64,
0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0x63, 0x0a,
0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c, 0x0a, 0x09, 0x72, 0x65,
0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72,
0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79,
0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, 0x04, 0x62, 0x6f,
0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a, 0x04, 0x74,
0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x2e,
0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61,
0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x61, 0x79,
0x6c, 0x6f, 0x61, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x77, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e,
0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x77, 0x67, 0x4c, 0x69,
0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x42,
0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
0x52, 0x0e, 0x6e, 0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
0x12, 0x28, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14,
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
0x4d, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x2c, 0x0a, 0x11, 0x66, 0x65,
0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x18,
0x06, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x11, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x53,
0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x12, 0x49, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x07, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76,
0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52,
0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72,
0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64,
0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f,
0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b,
0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c, 0x0a, 0x0a, 0x5f,
0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04, 0x4d, 0x6f, 0x64,
0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28,
0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, 0x42, 0x09,
0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52, 0x6f, 0x73,
0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a, 0x0f,
0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18,
0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, 0x02, 0x20,
0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0x8e, 0x02, 0x0a, 0x0e, 0x53, 0x69, 0x67,
0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53,
0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68,
0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x15, 0x53, 0x65, 0x6e,
0x64, 0x57, 0x69, 0x74, 0x68, 0x44, 0x65, 0x6c, 0x69, 0x76, 0x65, 0x72, 0x79, 0x43, 0x68, 0x65,
0x63, 0x6b, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x59,
0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12,
0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x56, 0x0a, 0x10, 0x45, 0x6e, 0x63, 0x72,
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03,
0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c,
0x0a, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28,
0x09, 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04,
0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
0x22, 0x63, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b,
0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c, 0x0a,
0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x77, 0x67, 0x4c, 0x69, 0x73,
0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x77,
0x67, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x6e,
0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20,
0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x12, 0x28, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x2c, 0x0a,
0x11, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74,
0x65, 0x64, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x11, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72,
0x65, 0x73, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x12, 0x49, 0x0a, 0x0f, 0x72,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x07,
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f,
0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73,
0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70,
0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06,
0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44,
0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10,
0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c,
0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04,
0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01,
0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01,
0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f,
0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b,
0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73,
0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72,
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e,
0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c,
0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d,
0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e,
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a,
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -536,7 +528,6 @@ var file_signalexchange_proto_goTypes = []interface{}{
(*Body)(nil), // 3: signalexchange.Body
(*Mode)(nil), // 4: signalexchange.Mode
(*RosenpassConfig)(nil), // 5: signalexchange.RosenpassConfig
(*emptypb.Empty)(nil), // 6: google.protobuf.Empty
}
var file_signalexchange_proto_depIdxs = []int32{
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
@@ -544,13 +535,11 @@ var file_signalexchange_proto_depIdxs = []int32{
4, // 2: signalexchange.Body.mode:type_name -> signalexchange.Mode
5, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
1, // 4: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 5: signalexchange.SignalExchange.SendWithDeliveryCheck:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
6, // 8: signalexchange.SignalExchange.SendWithDeliveryCheck:output_type -> google.protobuf.Empty
1, // 9: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
7, // [7:10] is the sub-list for method output_type
4, // [4:7] is the sub-list for method input_type
1, // 5: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
6, // [6:8] is the sub-list for method output_type
4, // [4:6] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name

View File

@@ -1,7 +1,6 @@
syntax = "proto3";
import "google/protobuf/descriptor.proto";
import "google/protobuf/empty.proto";
option go_package = "/proto";
@@ -10,7 +9,6 @@ package signalexchange;
service SignalExchange {
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
rpc Send(EncryptedMessage) returns (EncryptedMessage) {}
rpc SendWithDeliveryCheck(EncryptedMessage) returns (google.protobuf.Empty) {}
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
rpc ConnectStream(stream EncryptedMessage) returns (stream EncryptedMessage) {}
}

View File

@@ -7,7 +7,6 @@ import (
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
// This is a compile-time assertion to ensure that this generated file
@@ -21,7 +20,6 @@ const _ = grpc.SupportPackageIsVersion7
type SignalExchangeClient interface {
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
Send(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
SendWithDeliveryCheck(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*emptypb.Empty, error)
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error)
}
@@ -43,15 +41,6 @@ func (c *signalExchangeClient) Send(ctx context.Context, in *EncryptedMessage, o
return out, nil
}
func (c *signalExchangeClient) SendWithDeliveryCheck(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*emptypb.Empty, error) {
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, "/signalexchange.SignalExchange/SendWithDeliveryCheck", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *signalExchangeClient) ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &SignalExchange_ServiceDesc.Streams[0], "/signalexchange.SignalExchange/ConnectStream", opts...)
if err != nil {
@@ -89,7 +78,6 @@ func (x *signalExchangeConnectStreamClient) Recv() (*EncryptedMessage, error) {
type SignalExchangeServer interface {
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
Send(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
SendWithDeliveryCheck(context.Context, *EncryptedMessage) (*emptypb.Empty, error)
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
ConnectStream(SignalExchange_ConnectStreamServer) error
mustEmbedUnimplementedSignalExchangeServer()
@@ -102,9 +90,6 @@ type UnimplementedSignalExchangeServer struct {
func (UnimplementedSignalExchangeServer) Send(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method Send not implemented")
}
func (UnimplementedSignalExchangeServer) SendWithDeliveryCheck(context.Context, *EncryptedMessage) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SendWithDeliveryCheck not implemented")
}
func (UnimplementedSignalExchangeServer) ConnectStream(SignalExchange_ConnectStreamServer) error {
return status.Errorf(codes.Unimplemented, "method ConnectStream not implemented")
}
@@ -139,24 +124,6 @@ func _SignalExchange_Send_Handler(srv interface{}, ctx context.Context, dec func
return interceptor(ctx, in, info, handler)
}
func _SignalExchange_SendWithDeliveryCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SignalExchangeServer).SendWithDeliveryCheck(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/signalexchange.SignalExchange/SendWithDeliveryCheck",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SignalExchangeServer).SendWithDeliveryCheck(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _SignalExchange_ConnectStream_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(SignalExchangeServer).ConnectStream(&signalExchangeConnectStreamServer{stream})
}
@@ -194,10 +161,6 @@ var SignalExchange_ServiceDesc = grpc.ServiceDesc{
MethodName: "Send",
Handler: _SignalExchange_Send_Handler,
},
{
MethodName: "SendWithDeliveryCheck",
Handler: _SignalExchange_SendWithDeliveryCheck_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -2,7 +2,6 @@ package cmd
import (
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
@@ -10,19 +9,6 @@ import (
"github.com/spf13/pflag"
)
func EnvDisableSendWithDeliveryCheck() bool {
envVar := "NB_DISABLE_SEND_WITH_DELIVERY_CHECK"
value, present := os.LookupEnv(envVar)
if !present {
return false
}
if parsed, err := strconv.ParseBool(value); err == nil {
return parsed
}
return false
}
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_
func setFlagsFromEnvVars(cmd *cobra.Command) {
flags := cmd.PersistentFlags()

View File

@@ -10,7 +10,6 @@ import (
"net/http"
// nolint:gosec
_ "net/http/pprof"
"os"
"time"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -93,9 +92,7 @@ var (
RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
if os.Getenv("NB_PPROF_ENABLE") == "true" {
startPprof()
}
startPprof()
opts, certManager, err := getTLSConfigurations()
if err != nil {
@@ -117,11 +114,7 @@ var (
}
}()
optsSignal := &server.Options{
DisableSendWithDeliveryCheck: EnvDisableSendWithDeliveryCheck(),
}
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter, optsSignal)
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter)
if err != nil {
return fmt.Errorf("creating signal server: %v", err)
}
@@ -157,7 +150,7 @@ var (
serveHTTP(httpListener, grpcRootHandler)
}
if signalPort != legacyGRPCPort && os.Getenv("NB_DISABLE_FALLBACK_GRPC") != "true" {
if signalPort != legacyGRPCPort {
// The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal
// are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000.
compatListener, err = serveGRPC(grpcServer, legacyGRPCPort)

View File

@@ -14,7 +14,6 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
gproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/netbirdio/signal-dispatcher/dispatcher"
@@ -23,8 +22,6 @@ import (
"github.com/netbirdio/netbird/signal/peer"
)
var ErrPeerNotConnected = errors.New("peer not connected")
const (
labelType = "type"
labelTypeError = "error"
@@ -52,32 +49,20 @@ var (
ErrPeerRegisteredAgain = errors.New("peer registered again")
)
type Options struct {
// Disable SendWithDeliveryCheck method
DisableSendWithDeliveryCheck bool
}
// Server an instance of a Signal server
type Server struct {
metrics *metrics.AppMetrics
disableSendWithDeliveryCheck bool
registry *peer.Registry
proto.UnimplementedSignalExchangeServer
dispatcher *dispatcher.Dispatcher
metrics *metrics.AppMetrics
successHeader metadata.MD
sendTimeout time.Duration
directSendDisabled bool
sendTimeout time.Duration
}
// NewServer creates a new Signal server
func NewServer(ctx context.Context, meter metric.Meter, opts *Options) (*Server, error) {
if opts == nil {
opts = &Options{}
}
func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter)
if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err)
@@ -96,21 +81,11 @@ func NewServer(ctx context.Context, meter metric.Meter, opts *Options) (*Server,
}
s := &Server{
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
sendTimeout: sTimeout,
disableSendWithDeliveryCheck: opts.DisableSendWithDeliveryCheck,
}
if directSendDisabled := os.Getenv("NB_SIGNAL_DIRECT_SEND_DISABLED"); directSendDisabled == "true" {
s.directSendDisabled = true
log.Warn("direct send to connected peers is disabled")
}
if opts.DisableSendWithDeliveryCheck {
log.Warn("SendWithDeliveryCheck method is disabled")
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
sendTimeout: sTimeout,
}
return s, nil
@@ -120,51 +95,12 @@ func NewServer(ctx context.Context, meter metric.Meter, opts *Options) (*Server,
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
if _, found := s.registry.Get(msg.RemoteKey); found && !s.directSendDisabled {
_ = s.forwardMessageToPeer(ctx, msg)
if _, found := s.registry.Get(msg.RemoteKey); found {
s.forwardMessageToPeer(ctx, msg)
return &proto.EncryptedMessage{}, nil
}
if _, err := s.dispatcher.SendMessage(ctx, msg, false); err != nil {
log.Errorf("error sending message via dispatcher: %v", err)
}
return &proto.EncryptedMessage{}, nil
}
// SendWithDeliveryCheck forwards a message to the signal peer with error handling
// When the remote peer is not connected it returns codes.NotFound error, otherwise it returns other types of errors
// that can be retried. In case codes.NotFound is returned the caller should not retry sending the message. The remote
// peer should send a new offer to re-establish the connection when it comes back online.
// Todo: double check the thread safe registry management. When both peer come online at the same time then both peers
// might not be registered yet when the first message is sent.
func (s *Server) SendWithDeliveryCheck(ctx context.Context, msg *proto.EncryptedMessage) (*emptypb.Empty, error) {
if s.disableSendWithDeliveryCheck {
log.Tracef("SendWithDeliveryCheck is disabled")
return nil, status.Errorf(codes.Unimplemented, "SendWithDeliveryCheck method is disabled")
}
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
if _, found := s.registry.Get(msg.RemoteKey); found && !s.directSendDisabled {
if err := s.forwardMessageToPeer(ctx, msg); err != nil {
if errors.Is(err, ErrPeerNotConnected) {
log.Tracef("remote peer [%s] not connected", msg.RemoteKey)
return nil, status.Errorf(codes.NotFound, "remote peer not connected")
}
log.Errorf("error sending message with delivery check to peer [%s]: %v", msg.RemoteKey, err)
return nil, status.Errorf(codes.Internal, "error forwarding message to peer: %v", err)
}
return &emptypb.Empty{}, nil
}
if _, err := s.dispatcher.SendMessage(ctx, msg, true); err != nil {
if errors.Is(err, dispatcher.ErrPeerNotConnected) {
log.Tracef("remote peer [%s] doesn't have a listener", msg.RemoteKey)
return nil, status.Errorf(codes.NotFound, "remote peer not connected")
}
return nil, err
}
return &emptypb.Empty{}, nil
return s.dispatcher.SendMessage(ctx, msg)
}
// ConnectStream connects to the exchange stream
@@ -172,7 +108,6 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
ctx, cancel := context.WithCancel(context.Background())
p, err := s.RegisterPeer(stream, cancel)
if err != nil {
log.Errorf("error registering peer: %v", err)
return err
}
@@ -182,7 +117,6 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
err = stream.SendHeader(s.successHeader)
if err != nil {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader)))
log.Errorf("error sending registration header to peer [%s] [streamID %d] : %v", p.Id, p.StreamID, err)
return err
}
@@ -207,7 +141,7 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, c
p := peer.NewPeer(id[0], stream, cancel)
if err := s.registry.Register(p); err != nil {
return nil, fmt.Errorf("error adding peer to registry peer: %w", err)
return nil, err
}
err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
if err != nil {
@@ -224,7 +158,7 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
s.registry.Deregister(p)
}
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) error {
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
log.Tracef("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
getRegistrationStart := time.Now()
@@ -236,7 +170,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
// todo respond to the sender?
return ErrPeerNotConnected
return
}
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
@@ -257,7 +191,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
if err != nil {
log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err)
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
return fmt.Errorf("error sending message to peer: %v", err)
return
}
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(ctx, 1)
@@ -266,13 +200,10 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
case <-dstPeer.Stream.Context().Done():
log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey)
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected)))
return fmt.Errorf("destination peer disconnected")
case <-time.After(s.sendTimeout):
dstPeer.Cancel() // cancel the peer context to trigger deregistration
log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey)
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout)))
return fmt.Errorf("sending message to peer timeout")
}
return nil
}

View File

@@ -1,9 +1,13 @@
package version
import "golang.org/x/sys/windows/registry"
import (
"golang.org/x/sys/windows/registry"
"runtime"
)
const (
urlWinExe = "https://pkgs.netbird.io/windows/x64"
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
)
var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
@@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
// DownloadUrl return with the proper download link
func DownloadUrl() string {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
if err == nil {
return urlWinExe
} else {
if err != nil {
return downloadURL
}
url := urlWinExe
if runtime.GOARCH == "arm64" {
url = urlWinExeArm
}
return url
}