diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..d53c5f06b 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/cmd/login.go b/client/cmd/login.go index 3ac211805..40b55f858 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/user" "runtime" "strings" @@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := open.Run(verificationURIComplete); err != nil { + if err := openBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } +// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func openBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} + // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..d78372c9e 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return "" } - // Include action in the ipset name to prevent squashing rules with different actions actionSuffix := "" if action == firewall.ActionDrop { actionSuffix = "-drop" diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 54fbb002c..6aff53b92 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff { // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - if tlsEnabled { + // for js, the outer websocket layer takes care of tls + if tlsEnabled && runtime.GOOS != "js" { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) @@ -37,9 +38,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone } transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - // for js, outer websocket layer takes care of tls verification via WithCustomDialer - InsecureSkipVerify: runtime.GOOS == "js", - RootCAs: certPool, + RootCAs: certPool, })) } diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 84afc38f5..96b286175 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + // Get the existing peer to preserve its allowed IPs + existingPeer, err := c.getPeer(c.deviceName, peerKey) + if err != nil { + return fmt.Errorf("get peer: %w", err) + } + + removePeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil { + return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err) + } + + //Re-add the peer without the endpoint but same AllowedIPs + reAddPeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + AllowedIPs: existingPeer.AllowedIPs, + ReplaceAllowedIPs: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil { + return fmt.Errorf( + `error re-adding peer %s to interface %s with allowed IPs %v: %w`, + peerKey, c.deviceName, existingPeer.AllowedIPs, err, + ) + } + + return nil +} + func (c *KernelConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index f744e0127..bc875b73c 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return fmt.Errorf("parse peer key: %w", err) + } + + ipcStr, err := c.device.IpcGet() + if err != nil { + return fmt.Errorf("get IPC config: %w", err) + } + + // Parse current status to get allowed IPs for the peer + stats, err := parseStatus(c.deviceName, ipcStr) + if err != nil { + return fmt.Errorf("parse IPC config: %w", err) + } + + var allowedIPs []net.IPNet + found := false + for _, peer := range stats.Peers { + if peer.PublicKey == peerKey { + allowedIPs = peer.AllowedIPs + found = true + break + } + } + if !found { + return fmt.Errorf("peer %s not found", peerKey) + } + + // remove the peer from the WireGuard configuration + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil { + return fmt.Errorf("failed to remove peer: %s", ipcErr) + } + + // Build the peer config + peer = wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: allowedIPs, + } + + config = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil { + return fmt.Errorf("remove endpoint address: %w", err) + } + + return nil +} + func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea98..c0c829825 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -23,4 +23,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cc..48346fc0f 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net { return nil } +// GetICEBind returns the ICEBind instance +func (t *WGTunDevice) GetICEBind() EndpointManager { + return t.iceBind +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe088..acd5f6f11 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf..f96edf992 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a53..2a836f846 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns nil for kernel mode devices +func (t *TunKernelDevice) GetICEBind() EndpointManager { + return nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e37321b68..40d8fdac8 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -21,6 +21,7 @@ type Bind interface { conn.Bind GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) ActivityRecorder() *bind.ActivityRecorder + EndpointManager } type TunNetstackDevice struct { @@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device { func (t *TunNetstackDevice) GetNet() *netstack.Net { return t.net } + +// GetICEBind returns the bind instance +func (t *TunNetstackDevice) GetICEBind() EndpointManager { + return t.bind +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a32..24654fc03 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error { func (t *USPDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *USPDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a..96350df8a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/endpoint_manager.go b/client/iface/device/endpoint_manager.go new file mode 100644 index 000000000..b53888baa --- /dev/null +++ b/client/iface/device/endpoint_manager.go @@ -0,0 +1,13 @@ +package device + +import ( + "net" + "net/netip" +) + +// EndpointManager manages fake IP to connection mappings for userspace bind implementations. +// Implemented by bind.ICEBind and bind.RelayBindJS. +type EndpointManager interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) +} diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index 1f40b0d46..db53d9c3a 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -21,4 +21,5 @@ type WGConfigurer interface { GetStats() (map[string]configurer.WGStats, error) FullStats() (*configurer.Stats, error) LastActivities() map[string]monotime.Time + RemoveEndpointAddress(peerKey string) error } diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97..cdfcea48d 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,4 +21,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/iface.go b/client/iface/iface.go index 609572561..07235a995 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetBind returns the EndpointManager userspace bind mode. +func (w *WGIface) GetBind() device.EndpointManager { + w.mu.Lock() + defer w.mu.Unlock() + + if w.tun == nil { + return nil + } + return w.tun.GetICEBind() +} + // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind @@ -148,6 +159,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } +func (w *WGIface) RemoveEndpointAddress(peerKey string) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } + + log.Debugf("Removing endpoint address: %s", peerKey) + return w.configurer.RemoveEndpointAddress(peerKey) +} + // RemovePeer removes a Wireguard Peer from the interface iface func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5ca950297..965decc73 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -29,11 +29,6 @@ type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } -type protoMatch struct { - ips map[string]int - policyID []byte -} - // DefaultManager uses firewall manager to handle type DefaultManager struct { firewall firewall.Manager @@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout } func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { - rules, squashedProtocols := d.squashAcceptRules(networkMap) + rules := networkMap.FirewallRules enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig.SshEnabled - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - enableSSH = enableSSH && !ok - } - if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { - enableSSH = enableSSH && !ok - } - // if TCP protocol rules not squashed and SSH enabled - // we add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP 22 port). + // If SSH enabled, add default firewall rule which accepts connection to any peer + // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", @@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID( return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } -// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type -// to all peers in the network map to one rule which just accepts that type of the traffic. -// -// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, -// but other has port definitions or has drop policy. -func (d *DefaultManager) squashAcceptRules( - networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { - totalIPs := 0 - for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { - for range p.AllowedIps { - totalIPs++ - } - } - - in := map[mgmProto.RuleProtocol]*protoMatch{} - out := map[mgmProto.RuleProtocol]*protoMatch{} - - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} - - // this function we use to do calculation, can we squash the rules by protocol or not. - // We summ amount of Peers IP for given protocol we found in original rules list. - // But we zeroed the IP's for protocol if: - // 1. Any of the rule has DROP action type. - // 2. Any of rule contains Port. - // - // We zeroed this to notify squash function that this protocol can't be squashed. - addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || - r.Port != "" || !portInfoEmpty(r.PortInfo) - - if hasPortRestrictions { - // Don't squash rules with port restrictions - protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} - return - } - - if _, ok := protocols[r.Protocol]; !ok { - protocols[r.Protocol] = &protoMatch{ - ips: map[string]int{}, - // store the first encountered PolicyID for this protocol - policyID: r.PolicyID, - } - } - - // special case, when we receive this all network IP address - // it means that rules for that protocol was already optimized on the - // management side - if r.PeerIP == "0.0.0.0" { - squashedRules = append(squashedRules, r) - squashedProtocols[r.Protocol] = struct{}{} - return - } - - ipset := protocols[r.Protocol].ips - - if _, ok := ipset[r.PeerIP]; ok { - return - } - ipset[r.PeerIP] = i - } - - for i, r := range networkMap.FirewallRules { - // calculate squash for different directions - if r.Direction == mgmProto.RuleDirection_IN { - addRuleToCalculationMap(i, r, in) - } else { - addRuleToCalculationMap(i, r, out) - } - } - - // order of squashing by protocol is important - // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.RuleProtocol{ - mgmProto.RuleProtocol_ALL, - mgmProto.RuleProtocol_ICMP, - mgmProto.RuleProtocol_TCP, - mgmProto.RuleProtocol_UDP, - } - - squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { - for _, protocol := range protocolOrders { - match, ok := matches[protocol] - if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { - // don't squash if : - // 1. Rules not cover all peers in the network - // 2. Rules cover only one peer in the network. - continue - } - - // add special rule 0.0.0.0 which allows all IP's in our firewall implementations - squashedRules = append(squashedRules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: direction, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: protocol, - PolicyID: match.policyID, - }) - squashedProtocols[protocol] = struct{}{} - - if protocol == mgmProto.RuleProtocol_ALL { - // if we have ALL traffic type squashed rule - // it allows all other type of traffic, so we can stop processing - break - } - } - } - - squash(in, mgmProto.RuleDirection_IN) - squash(out, mgmProto.RuleDirection_OUT) - - // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - return squashedRules, squashedProtocols - } - - if len(squashedRules) == 0 { - return networkMap.FirewallRules, squashedProtocols - } - - var rules []*mgmProto.FirewallRule - // filter out rules which was squashed from final list - // if we also have other not squashed rules. - for i, r := range networkMap.FirewallRules { - if _, ok := squashedProtocols[r.Protocol]; ok { - if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } - } - rules = append(rules, r) - } - - return append(rules, squashedRules...), squashedProtocols -} - // getRuleGroupingSelector takes all rule properties except IP address to build selector func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 664476ef4..daf4979ce 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) { }) } -func TestDefaultManagerSquashRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, 2, len(rules)) - - r := rules[0] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) - - r = rules[1] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) -} - -func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, len(networkMap.FirewallRules), len(rules)) -} - -func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { - tests := []struct { - name string - rules []*mgmProto.FirewallRule - expectedCount int - description string - }{ - { - name: "should not squash rules with port ranges", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with port ranges should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with specific ports", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with specific ports should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with legacy port field", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - }, - expectedCount: 4, - description: "Rules with legacy port field should not be squashed", - }, - { - name: "should not squash rules with DROP action", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "Rules with DROP action should not be squashed", - }, - { - name: "should squash rules without port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 1, - description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule", - }, - { - name: "mixed rules should not squash protocol with port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "TCP should not be squashed because one rule has port restrictions", - }, - { - name: "should squash UDP but not TCP when TCP has port restrictions", - rules: []*mgmProto.FirewallRule{ - // TCP rules with port restrictions - should NOT be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - // UDP rules without port restrictions - SHOULD be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) - description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: tt.rules, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - - assert.Equal(t, tt.expectedCount, len(rules), tt.description) - - // For squashed rules, verify we get the expected 0.0.0.0 rule - if tt.expectedCount == 1 { - assert.Equal(t, "0.0.0.0", rules[0].PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) - } - }) - } -} - func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index ec920c5f3..442f54e71 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/internal/debug/wgshow.go b/client/internal/debug/wgshow.go index e4b4c2368..8233ca510 100644 --- a/client/internal/debug/wgshow.go +++ b/client/internal/debug/wgshow.go @@ -14,6 +14,9 @@ type WGIface interface { } func (g *BundleGenerator) addWgShow() error { + if g.statusRecorder == nil { + return fmt.Errorf("no status recorder available for wg show") + } result, err := g.statusRecorder.PeersStatus() if err != nil { return err diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..74111d335 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -197,6 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,9 +209,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } @@ -273,9 +275,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 0e8a53a63..d9854c033 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -31,6 +31,7 @@ const ( systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" + systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS" systemdDbusResolvConfModeForeign = "foreign" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" @@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana log.Warnf("failed to set DNSSEC to 'no': %v", err) } + // We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off + if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil { + log.Warnf("failed to set DNSOverTLS to 'no': %v", err) + } + var ( searchDomains []string matchDomains []string diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go new file mode 100644 index 000000000..43fe2d020 --- /dev/null +++ b/client/internal/dnsfwd/cache.go @@ -0,0 +1,78 @@ +package dnsfwd + +import ( + "net/netip" + "slices" + "strings" + "sync" + + "github.com/miekg/dns" +) + +type cache struct { + mu sync.RWMutex + records map[string]*cacheEntry +} + +type cacheEntry struct { + ip4Addrs []netip.Addr + ip6Addrs []netip.Addr +} + +func newCache() *cache { + return &cache{ + records: make(map[string]*cacheEntry), + } +} + +func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.records[normalizeDomain(domain)] + if !exists { + return nil, false + } + + switch reqType { + case dns.TypeA: + return slices.Clone(entry.ip4Addrs), true + case dns.TypeAAAA: + return slices.Clone(entry.ip6Addrs), true + default: + return nil, false + } + +} + +func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { + c.mu.Lock() + defer c.mu.Unlock() + norm := normalizeDomain(domain) + entry, exists := c.records[norm] + if !exists { + entry = &cacheEntry{} + c.records[norm] = entry + } + + switch reqType { + case dns.TypeA: + entry.ip4Addrs = slices.Clone(addrs) + case dns.TypeAAAA: + entry.ip6Addrs = slices.Clone(addrs) + } +} + +// unset removes cached entries for the given domain and request type. +func (c *cache) unset(domain string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.records, normalizeDomain(domain)) +} + +// normalizeDomain converts an input domain into a canonical form used as cache key: +// lowercase and fully-qualified (with trailing dot). +func normalizeDomain(domain string) string { + // dns.Fqdn ensures trailing dot; ToLower for consistent casing + return dns.Fqdn(strings.ToLower(domain)) +} diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go new file mode 100644 index 000000000..c23f0f31d --- /dev/null +++ b/client/internal/dnsfwd/cache_test.go @@ -0,0 +1,86 @@ +package dnsfwd + +import ( + "net/netip" + "testing" +) + +func mustAddr(t *testing.T, s string) netip.Addr { + t.Helper() + a, err := netip.ParseAddr(s) + if err != nil { + t.Fatalf("parse addr %s: %v", s, err) + } + return a +} + +func TestCacheNormalization(t *testing.T) { + c := newCache() + + // Mixed case, without trailing dot + domainInput := "ExAmPlE.CoM" + ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")} + c.set(domainInput, 1 /* dns.TypeA */, ipv4) + + // Lookup with lower, with trailing dot + if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok) + } + + // Lookup with different casing again + if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok) + } +} + +func TestCacheSeparateTypes(t *testing.T) { + c := newCache() + + domain := "test.local" + ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")} + ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")} + + c.set(domain, 1 /* A */, ipv4) + c.set(domain, 28 /* AAAA */, ipv6) + + got4, ok4 := c.get(domain, 1) + if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] { + t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4) + } + + got6, ok6 := c.get(domain, 28) + if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] { + t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6) + } +} + +func TestCacheCloneOnGetAndSet(t *testing.T) { + c := newCache() + domain := "clone.test" + + src := []netip.Addr{mustAddr(t, "8.8.8.8")} + c.set(domain, 1, src) + + // Mutate source slice; cache should be unaffected + src[0] = mustAddr(t, "9.9.9.9") + + got, ok := c.get(domain, 1) + if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" { + t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok) + } + + // Mutate returned slice; internal cache should remain unchanged + got[0] = mustAddr(t, "4.4.4.4") + got2, ok2 := c.get(domain, 1) + if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" { + t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2) + } +} + +func TestCacheMiss(t *testing.T) { + c := newCache() + if got, ok := c.get("missing.example", 1); ok || got != nil { + t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) + } +} + diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index d912919a1..7a262fa4c 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -46,6 +46,7 @@ type DNSForwarder struct { fwdEntries []*ForwarderEntry firewall firewaller resolver resolver + cache *cache } func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { @@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat firewall: firewall, statusRecorder: statusRecorder, resolver: net.DefaultResolver, + cache: newCache(), } } @@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() + // remove cache entries for domains that no longer appear + f.removeStaleCacheEntries(f.fwdEntries, entries) + f.fwdEntries = entries log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } +// removeStaleCacheEntries unsets cache items for domains that were present +// in the old list but not present in the new list. +func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) { + if f.cache == nil { + return + } + + newSet := make(map[string]struct{}, len(newEntries)) + for _, e := range newEntries { + if e == nil { + continue + } + newSet[e.Domain.PunycodeString()] = struct{}{} + } + + for _, e := range oldEntries { + if e == nil { + continue + } + pattern := e.Domain.PunycodeString() + if _, ok := newSet[pattern]; !ok { + f.cache.unset(pattern) + } + } +} + func (f *DNSForwarder) Close(ctx context.Context) error { var result *multierror.Error @@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) + f.cache.set(domain, question.Qtype, ips) return resp } @@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns resp.Rcode = dns.RcodeSuccess } -// handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { +// handleDNSError processes DNS lookup errors and sends an appropriate error response. +func (f *DNSForwarder) handleDNSError( + ctx context.Context, + w dns.ResponseWriter, + question dns.Question, + resp *dns.Msg, + domain string, + err error, +) { + // Default to SERVFAIL; override below when appropriate. + resp.Rcode = dns.RcodeServerFailure + + qType := question.Qtype + qTypeName := dns.TypeToString[qType] + + // Prefer typed DNS errors; fall back to generic logging otherwise. var dnsErr *net.DNSError - - switch { - case errors.As(err, &dnsErr): - resp.Rcode = dns.RcodeServerFailure - if dnsErr.IsNotFound { - f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype) + if !errors.As(err, &dnsErr) { + log.Warnf(errResolveFailed, domain, err) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } + return + } - if dnsErr.Server != "" { - log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) - } else { - log.Warnf(errResolveFailed, domain, err) + // NotFound: set NXDOMAIN / appropriate code via helper. + if dnsErr.IsNotFound { + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } - default: - resp.Rcode = dns.RcodeServerFailure + f.cache.set(domain, question.Qtype, nil) + return + } + + // Upstream failed but we might have a cached answer—serve it if present. + if ips, ok := f.cache.get(domain, qType); ok { + if len(ips) > 0 { + log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) + f.addIPsToResponse(resp, domain, ips) + resp.Rcode = dns.RcodeSuccess + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write cached DNS response: %v", writeErr) + } + } else { // send NXDOMAIN / appropriate code if cache is empty + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) + } + } + return + } + + // No cache. Log with or without the server field for more context. + if dnsErr.Server != "" { + log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) + } else { log.Warnf(errResolveFailed, domain, err) } - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write failure DNS response: %v", err) + // Write final failure response. + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 57085e19a..c1c95a2c1 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") } +// Ensures that when the first query succeeds and populates the cache, +// a subsequent upstream failure still returns a successful response from cache. +func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("example.com") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("1.2.3.4") + + // First call resolves successfully and populates cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{ip}, nil).Once() + + // Second call fails upstream; forwarder should serve from cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + // First query: populate cache + q1 := &dns.Msg{} + q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Second query: serve from cache after upstream failure + q2 := &dns.Msg{} + q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + +// Verifies that cache normalization works across casing and trailing dot variations. +func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("ExAmPlE.CoM") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("9.8.7.6") + + // Initial resolution with mixed case to populate cache + mixedQuery := "ExAmPlE.CoM" + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))). + Return([]netip.Addr{ip}, nil).Once() + + q1 := &dns.Msg{} + q1.SetQuestion(mixedQuery+".", dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Subsequent query without dot and upper case should hit cache even if upstream fails + // Forwarder lowercases and uses the question name as-is (no trailing dot here) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + q2 := &dns.Msg{} + q2.SetQuestion("EXAMPLE.COM", dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp) + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { // Test complex overlapping pattern scenarios mockFirewall := &MockFirewall{} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 5c7a3fbdd..a3a4ba40f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -40,7 +40,6 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder - port uint16 } func ListenPort() uint16 { @@ -49,11 +48,16 @@ func ListenPort() uint16 { return listenPort } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { +func SetListenPort(port uint16) { + listenPortMu.Lock() + listenPort = port + listenPortMu.Unlock() +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, - port: port, } } @@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.port > 0 { - listenPortMu.Lock() - listenPort = m.port - listenPortMu.Unlock() - } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 9f36449fa..ac559d2b4 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1905,6 +1905,10 @@ func (e *Engine) updateDNSForwarder( return } + if forwarderPort > 0 { + dnsfwd.SetListenPort(forwarderPort) + } + if !enabled { if e.dnsForwardMgr == nil { return @@ -1918,7 +1922,7 @@ func (e *Engine) updateDNSForwarder( if len(fwdEntries) > 0 { switch { case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil @@ -1948,7 +1952,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { log.Errorf("failed to stop DNS forward: %v", err) } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 344104405..2f1098100 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -105,6 +105,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RemoveEndpointAddress(_ string) error { + return nil +} + func (m *MockWGIface) FullStats() (*configurer.Stats, error) { return nil, fmt.Errorf("not implemented") } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 690fdb7cc..98fe01912 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -28,6 +28,7 @@ type wgIfaceBase interface { UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemoveEndpointAddress(key string) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error diff --git a/client/internal/lazyconn/activity/lazy_conn.go b/client/internal/lazyconn/activity/lazy_conn.go new file mode 100644 index 000000000..2564a9905 --- /dev/null +++ b/client/internal/lazyconn/activity/lazy_conn.go @@ -0,0 +1,82 @@ +package activity + +import ( + "context" + "io" + "net" + "time" +) + +// lazyConn detects activity when WireGuard attempts to send packets. +// It does not deliver packets, only signals that activity occurred. +type lazyConn struct { + activityCh chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// newLazyConn creates a new lazyConn for activity detection. +func newLazyConn() *lazyConn { + ctx, cancel := context.WithCancel(context.Background()) + return &lazyConn{ + activityCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } +} + +// Read blocks until the connection is closed. +func (c *lazyConn) Read(_ []byte) (n int, err error) { + <-c.ctx.Done() + return 0, io.EOF +} + +// Write signals activity detection when ICEBind routes packets to this endpoint. +func (c *lazyConn) Write(b []byte) (n int, err error) { + if c.ctx.Err() != nil { + return 0, io.EOF + } + + select { + case c.activityCh <- struct{}{}: + default: + } + + return len(b), nil +} + +// ActivityChan returns the channel that signals when activity is detected. +func (c *lazyConn) ActivityChan() <-chan struct{} { + return c.activityCh +} + +// Close closes the connection. +func (c *lazyConn) Close() error { + c.cancel() + return nil +} + +// LocalAddr returns the local address. +func (c *lazyConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// RemoteAddr returns the remote address. +func (c *lazyConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// SetDeadline sets the read and write deadlines. +func (c *lazyConn) SetDeadline(_ time.Time) error { + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (c *lazyConn) SetReadDeadline(_ time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *lazyConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go new file mode 100644 index 000000000..792d04215 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -0,0 +1,127 @@ +package activity + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type bindProvider interface { + GetBind() device.EndpointManager +} + +const ( + // lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers. + // The actual routing is done via fakeIP in ICEBind, not by this port. + lazyBindPort = 17473 +) + +// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode. +type BindListener struct { + wgIface WgInterface + peerCfg lazyconn.PeerConfig + done sync.WaitGroup + + lazyConn *lazyConn + bind device.EndpointManager + fakeIP netip.Addr +} + +// NewBindListener creates a listener that passes data directly through bind using LazyConn. +// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range. +func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) { + fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs) + if err != nil { + return nil, fmt.Errorf("derive fake IP: %w", err) + } + + d := &BindListener{ + wgIface: wgIface, + peerCfg: cfg, + bind: bind, + fakeIP: fakeIP, + } + + if err := d.setupLazyConn(); err != nil { + return nil, fmt.Errorf("setup lazy connection: %v", err) + } + + d.done.Add(1) + return d, nil +} + +// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. +// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). +// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { + if len(allowedIPs) == 0 { + return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") + } + + ourNetwork := wgIface.Address().Network + + var peerIP netip.Addr + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if !ip.Is4() { + continue + } + if ourNetwork.Contains(ip) { + peerIP = ip + break + } + } + + if !peerIP.IsValid() { + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + } + + octets := peerIP.As4() + fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) + return fakeIP, nil +} + +func (d *BindListener) setupLazyConn() error { + d.lazyConn = newLazyConn() + d.bind.SetEndpoint(d.fakeIP, d.lazyConn) + + endpoint := &net.UDPAddr{ + IP: d.fakeIP.AsSlice(), + Port: lazyBindPort, + } + return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil) +} + +// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed. +func (d *BindListener) ReadPackets() { + select { + case <-d.lazyConn.ActivityChan(): + d.peerCfg.Log.Infof("activity detected via LazyConn") + case <-d.lazyConn.ctx.Done(): + d.peerCfg.Log.Infof("exit from activity listener") + } + + d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { + d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) + } + + _ = d.lazyConn.Close() + d.bind.RemoveEndpoint(d.fakeIP) + d.done.Done() +} + +// Close stops the listener and cleans up resources. +func (d *BindListener) Close() { + d.peerCfg.Log.Infof("closing activity listener (LazyConn)") + + if err := d.lazyConn.Close(); err != nil { + d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err) + } + + d.done.Wait() +} diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go new file mode 100644 index 000000000..f86dd3877 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -0,0 +1,291 @@ +package activity + +import ( + "net" + "net/netip" + "runtime" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/lazyconn" + peerid "github.com/netbirdio/netbird/client/internal/peer/id" +) + +func isBindListenerPlatform() bool { + return runtime.GOOS == "windows" || runtime.GOOS == "js" +} + +// mockEndpointManager implements device.EndpointManager for testing +type mockEndpointManager struct { + endpoints map[netip.Addr]net.Conn +} + +func newMockEndpointManager() *mockEndpointManager { + return &mockEndpointManager{ + endpoints: make(map[netip.Addr]net.Conn), + } +} + +func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + m.endpoints[fakeIP] = conn +} + +func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) { + delete(m.endpoints, fakeIP) +} + +func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn { + return m.endpoints[fakeIP] +} + +// MockWGIfaceBind mocks WgInterface with bind support +type MockWGIfaceBind struct { + endpointMgr *mockEndpointManager +} + +func (m *MockWGIfaceBind) RemovePeer(string) error { + return nil +} + +func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { + return nil +} + +func (m *MockWGIfaceBind) IsUserspaceBind() bool { + return true +} + +func (m *MockWGIfaceBind) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +func (m *MockWGIfaceBind) GetBind() device.EndpointManager { + return m.endpointMgr +} + +func TestBindListener_Creation(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + expectedFakeIP := netip.MustParseAddr("127.2.0.2") + conn := mockEndpointMgr.GetEndpoint(expectedFakeIP) + require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager") + + _, ok := conn.(*lazyConn) + assert.True(t, ok, "Registered endpoint should be a lazyConn") + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestBindListener_ActivityDetection(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + fakeIP := listener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection") +} + +func TestBindListener_Close(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + fakeIP := listener.fakeIP + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close") +} + +func TestManager_BindMode(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + err := mgr.MonitorPeerActivity(cfg) + require.NoError(t, err) + + listener, exists := mgr.GetPeerListener(cfg.PeerConnID) + require.True(t, exists, "Peer listener should be found") + + bindListener, ok := listener.(*BindListener) + require.True(t, ok, "Listener should be BindListener, got %T", listener) + + fakeIP := bindListener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case peerConnID := <-mgr.OnActivityChan: + assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notification") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity") +} + +func TestManager_BindMode_MultiplePeers(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer1 := &MocPeer{PeerID: "testPeer1"} + peer2 := &MocPeer{PeerID: "testPeer2"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg1 := lazyconn.PeerConfig{ + PublicKey: peer1.PeerID, + PeerConnID: peer1.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + cfg2 := lazyconn.PeerConfig{ + PublicKey: peer2.PeerID, + PeerConnID: peer2.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")}, + Log: log.WithField("peer", "testPeer2"), + } + + err := mgr.MonitorPeerActivity(cfg1) + require.NoError(t, err) + + err = mgr.MonitorPeerActivity(cfg2) + require.NoError(t, err) + + listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID) + require.True(t, exists, "Peer1 listener should be found") + bindListener1 := listener1.(*BindListener) + + listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID) + require.True(t, exists, "Peer2 listener should be found") + bindListener2 := listener2.(*BindListener) + + conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP) + require.NotNil(t, conn1, "Peer1 endpoint should be registered") + _, err = conn1.Write([]byte{0x01}) + require.NoError(t, err) + + conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP) + require.NotNil(t, conn2, "Peer2 endpoint should be registered") + _, err = conn2.Write([]byte{0x02}) + require.NoError(t, err) + + receivedPeers := make(map[peerid.ConnID]bool) + for i := 0; i < 2; i++ { + select { + case peerConnID := <-mgr.OnActivityChan: + receivedPeers[peerConnID] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notifications") + } + } + + assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received") + assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received") +} diff --git a/client/internal/lazyconn/activity/listener_test.go b/client/internal/lazyconn/activity/listener_test.go deleted file mode 100644 index 98d7838d2..000000000 --- a/client/internal/lazyconn/activity/listener_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package activity - -import ( - "testing" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/lazyconn" -) - -func TestNewListener(t *testing.T) { - peer := &MocPeer{ - PeerID: "examplePublicKey1", - } - - cfg := lazyconn.PeerConfig{ - PublicKey: peer.PeerID, - PeerConnID: peer.ConnID(), - Log: log.WithField("peer", "examplePublicKey1"), - } - - l, err := NewListener(MocWGIface{}, cfg) - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - chanClosed := make(chan struct{}) - go func() { - defer close(chanClosed) - l.ReadPackets() - }() - - time.Sleep(1 * time.Second) - l.Close() - - select { - case <-chanClosed: - case <-time.After(time.Second): - } -} diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener_udp.go similarity index 64% rename from client/internal/lazyconn/activity/listener.go rename to client/internal/lazyconn/activity/listener_udp.go index 817ff00c3..e0b09be6c 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener_udp.go @@ -11,26 +11,27 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" ) -// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking -type Listener struct { +// UDPListener uses UDP sockets for activity detection in kernel mode. +type UDPListener struct { wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr done sync.Mutex - isClosed atomic.Bool // use to avoid error log when closing the listener + isClosed atomic.Bool } -func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { - d := &Listener{ +// NewUDPListener creates a listener that detects activity via UDP socket reads. +func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) { + d := &UDPListener{ wgIface: wgIface, peerCfg: cfg, } conn, err := d.newConn() if err != nil { - return nil, fmt.Errorf("failed to creating activity listener: %v", err) + return nil, fmt.Errorf("create UDP connection: %v", err) } d.conn = conn d.endpoint = conn.LocalAddr().(*net.UDPAddr) @@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error if err := d.createEndpoint(); err != nil { return nil, err } + d.done.Lock() - cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) + cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String()) return d, nil } -func (d *Listener) ReadPackets() { +// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. +func (d *UDPListener) ReadPackets() { for { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { @@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() { } d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) - if err := d.removeEndpoint(); err != nil { + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } - _ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" + // Ignore close error as it may return "use of closed network connection" if already closed. + _ = d.conn.Close() d.done.Unlock() } -func (d *Listener) Close() { +// Close stops the listener and cleans up resources. +func (d *UDPListener) Close() { d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) @@ -82,16 +87,12 @@ func (d *Listener) Close() { d.done.Lock() } -func (d *Listener) removeEndpoint() error { - return d.wgIface.RemovePeer(d.peerCfg.PublicKey) -} - -func (d *Listener) createEndpoint() error { +func (d *UDPListener) createEndpoint() error { d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) } -func (d *Listener) newConn() (*net.UDPConn, error) { +func (d *UDPListener) newConn() (*net.UDPConn, error) { addr := &net.UDPAddr{ Port: 0, IP: listenIP, diff --git a/client/internal/lazyconn/activity/listener_udp_test.go b/client/internal/lazyconn/activity/listener_udp_test.go new file mode 100644 index 000000000..d2adb9bf4 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_udp_test.go @@ -0,0 +1,110 @@ +package activity + +import ( + "net" + "net/netip" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func TestUDPListener_Creation(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + require.NotNil(t, listener.conn) + require.NotNil(t, listener.endpoint) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestUDPListener_ActivityDetection(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + conn, err := net.Dial("udp", listener.conn.LocalAddr().String()) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } +} + +func TestUDPListener_Close(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed") +} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 915fb9cb8..db283ec9a 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,21 +1,32 @@ package activity import ( + "errors" "net" "net/netip" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +// listener defines the contract for activity detection listeners. +type listener interface { + ReadPackets() + Close() +} + type WgInterface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + Address() wgaddr.Address } type Manager struct { @@ -23,7 +34,7 @@ type Manager struct { wgIface WgInterface - peers map[peerid.ConnID]*Listener + peers map[peerid.ConnID]listener done chan struct{} mu sync.Mutex @@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, - peers: make(map[peerid.ConnID]*Listener), + peers: make(map[peerid.ConnID]listener), done: make(chan struct{}), } return m @@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error { return nil } - listener, err := NewListener(m.wgIface, peerCfg) + listener, err := m.createListener(peerCfg) if err != nil { return err } - m.peers[peerCfg.PeerConnID] = listener + m.peers[peerCfg.PeerConnID] = listener go m.waitForTraffic(listener, peerCfg.PeerConnID) return nil } +func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) { + if !m.wgIface.IsUserspaceBind() { + return NewUDPListener(m.wgIface, peerCfg) + } + + // BindListener is only used on Windows and JS platforms: + // - JS: Cannot listen to UDP sockets + // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default + // gateway points to, preventing them from reaching the loopback interface. + // BindListener bypasses this by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" { + return NewUDPListener(m.wgIface, peerCfg) + } + + provider, ok := m.wgIface.(bindProvider) + if !ok { + return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") + } + + return NewBindListener(m.wgIface, provider.GetBind(), peerCfg) +} + func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { m.mu.Lock() defer m.mu.Unlock() @@ -82,8 +115,8 @@ func (m *Manager) Close() { } } -func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { - listener.ReadPackets() +func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) { + l.ReadPackets() m.mu.Lock() if _, ok := m.peers[peerConnID]; !ok { diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index ae6c31da4..0768d9219 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) @@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { return nil - } -// Add this method to the Manager struct -func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { +func (m MocWGIface) IsUserspaceBind() bool { + return false +} + +func (m MocWGIface) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +// GetPeerListener is a test helper to access listeners +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { m.mu.Lock() defer m.mu.Unlock() - listener, exists := m.peers[peerConnID] - return listener, exists + l, exists := m.peers[peerConnID] + return l, exists } func TestManager_MonitorPeerActivity(t *testing.T) { @@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("peer listener not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + // Get the UDP listener's address for triggering + udpListener, ok := listener.(*UDPListener) + if !ok { + t.Fatalf("expected UDPListener") + } + if err := trigger(udpListener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() + listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID) + udpListener, _ := listener.(*UDPListener) + addr := udpListener.conn.LocalAddr().String() mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) @@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer1 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener1, _ := listener.(*UDPListener) + if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer2 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener2, _ := listener.(*UDPListener) + if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 0351904f7..0626c1815 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -7,6 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/monotime" ) @@ -14,5 +15,6 @@ type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool + Address() wgaddr.Address LastActivities() map[string]monotime.Time } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8db9e58f4..68afe986a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) - conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) + conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) if !isForceRelayed() { - conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) + conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) @@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() { } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } changed := conn.statusICE.Get() != worker.StatusDisconnected @@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() { if conn.currentConnPriority == conntype.Relay { conn.Log.Debugf("clean up WireGuard config") conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } if conn.wgProxyRelay != nil { diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index c839ab147..6b47f95eb 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteOffer(OfferAnswer{ @@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case <-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") @@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteAnswer(OfferAnswer{ @@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case <-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") diff --git a/client/internal/peer/guard/env.go b/client/internal/peer/guard/env.go new file mode 100644 index 000000000..1ea2d21be --- /dev/null +++ b/client/internal/peer/guard/env.go @@ -0,0 +1,20 @@ +package guard + +import ( + "os" + "strconv" + "time" +) + +const ( + envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD" +) + +func GetICEMonitorPeriod() time.Duration { + if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" { + if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return defaultCandidatesMonitorPeriod +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 09cf9ae63..0f22ee7b0 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -16,8 +16,8 @@ import ( ) const ( - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second + defaultCandidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second ) type ICEMonitor struct { @@ -25,16 +25,19 @@ type ICEMonitor struct { iFaceDiscover stdnet.ExternalIFaceDiscover iceConfig icemaker.Config + tickerPeriod time.Duration currentCandidatesAddress []string candidatesMu sync.Mutex } -func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { +func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor { + log.Debugf("prepare ICE monitor with period: %s", period) cm := &ICEMonitor{ ReconnectCh: make(chan struct{}, 1), iFaceDiscover: iFaceDiscover, iceConfig: config, + tickerPeriod: period, } return cm } @@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { return } - ticker := time.NewTicker(candidatesMonitorPeriod) + // Initial check to populate the candidates for later comparison + if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil { + log.Warnf("Failed to check initial ICE candidates: %v", err) + } + + ticker := time.NewTicker(cm.tickerPeriod) defer ticker.Stop() for { diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 90e45426f..686430752 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -51,7 +51,7 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) go iceMonitor.Start(ctx, w.onICEChanged) w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 42eaea683..aff26f847 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -44,13 +44,19 @@ type OfferAnswer struct { } type Handshaker struct { - mu sync.Mutex - log *log.Entry - config ConnConfig - signaler *Signaler - ice *WorkerICE - relay *WorkerRelay - onNewOfferListeners []*OfferListener + mu sync.Mutex + log *log.Entry + config ConnConfig + signaler *Signaler + ice *WorkerICE + relay *WorkerRelay + // relayListener is not blocking because the listener is using a goroutine to process the messages + // and it will only keep the latest message if multiple offers are received in a short time + // this is to avoid blocking the handshaker if the listener is doing some heavy processing + // and also to avoid processing old offers if multiple offers are received in a short time + // the listener will always process the latest offer + relayListener *AsyncOfferListener + iceListener func(remoteOfferAnswer *OfferAnswer) // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer @@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } } -func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - l := NewOfferListener(offer) - h.onNewOfferListeners = append(h.onNewOfferListeners, l) +func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.relayListener = NewAsyncOfferListener(offer) +} + +func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.iceListener = offer } func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - // received confirmation from the remote peer -> ready to proceed + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) + } + if err := h.sendAnswer(); err != nil { h.log.Errorf("failed to send remote offer confirmation: %s", err) continue } - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) - } - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go index e2d3f3f38..772e2777f 100644 --- a/client/internal/peer/handshaker_listener.go +++ b/client/internal/peer/handshaker_listener.go @@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string { return oa.SessionID.String() } -type OfferListener struct { +type AsyncOfferListener struct { fn callbackFunc running bool latest *OfferAnswer mu sync.Mutex } -func NewOfferListener(fn callbackFunc) *OfferListener { - return &OfferListener{ +func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener { + return &AsyncOfferListener{ fn: fn, } } -func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { +func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) { o.mu.Lock() defer o.mu.Unlock() diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go index 8363741a5..1a7156d10 100644 --- a/client/internal/peer/handshaker_listener_test.go +++ b/client/internal/peer/handshaker_listener_test.go @@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) { runChan <- struct{}{} } - hl := NewOfferListener(longRunningFn) + hl := NewAsyncOfferListener(longRunningFn) hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer) diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go index 0bcc7a68e..678396e61 100644 --- a/client/internal/peer/iface.go +++ b/client/internal/peer/iface.go @@ -18,4 +18,5 @@ type WGIface interface { GetStats() (map[string]configurer.WGStats, error) GetProxy() wgproxy.Proxy Address() wgaddr.Address + RemoveEndpointAddress(key string) error } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index eb886a4d3..3675f0157 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn * func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.muxAgent.Lock() + defer w.muxAgent.Unlock() - if w.agentConnecting { - w.log.Debugf("agent connection is in progress, skipping the offer") - w.muxAgent.Unlock() - return - } - - if w.agent != nil { + if w.agent != nil || w.agentConnecting { // backward compatibility with old clients that do not send session ID if remoteOfferAnswer.SessionID == nil { w.log.Debugf("agent already exists, skipping the offer") - w.muxAgent.Unlock() return } if w.remoteSessionID == *remoteOfferAnswer.SessionID { w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) - w.muxAgent.Unlock() return } w.log.Debugf("agent already exists, recreate the connection") @@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } + + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil } @@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { preferredCandidateTypes = icemaker.CandidateTypes() } - w.log.Debugf("recreate ICE agent") + if remoteOfferAnswer.SessionID != nil { + w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID) + } dialerCtx, dialerCancel := context.WithCancel(w.ctx) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) if err != nil { w.log.Errorf("failed to recreate ICE Agent: %s", err) - w.muxAgent.Unlock() return } w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true - w.muxAgent.Unlock() + if remoteOfferAnswer.SessionID != nil { + w.remoteSessionID = *remoteOfferAnswer.SessionID + } else { + w.remoteSessionID = "" + } go w.connect(dialerCtx, agent, remoteOfferAnswer) } @@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.muxAgent.Lock() w.agentConnecting = false w.lastSuccess = time.Now() - if remoteOfferAnswer.SessionID != nil { - w.remoteSessionID = *remoteOfferAnswer.SessionID - } w.muxAgent.Unlock() // todo: the potential problem is a race between the onConnectionStateChange @@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C } w.muxAgent.Lock() - // todo review does it make sense to generate new session ID all the time when w.agent==agent - sessionID, err := NewICESessionID() - if err != nil { - w.log.Errorf("failed to create new session ID: %s", err) - } - w.sessionID = sessionID if w.agent == agent { + // consider to remove from here and move to the OnNewOffer + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil w.agentConnecting = false + w.remoteSessionID = "" } w.muxAgent.Unlock() } @@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // notify the conn.onICEStateDisconnected changes to update the current used priority + w.closeAgent(agent, dialerCancel) + if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() } - w.closeAgent(agent, dialerCancel) default: return } diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 4e6b422f6..f03822089 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + WgPort: iface.DefaultWgPort, } if _, err := config.apply(input); err != nil { diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 45e37bf0e..90bde7707 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -5,11 +5,14 @@ import ( "errors" "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) @@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) { } } +func TestNewProfileDefaults(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + }) + require.NoError(t, err, "should create new config") + + assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default") + assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default") + assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated") + assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated") + assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default") + assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820") + assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default") + assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default") + assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set") + assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set") + assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults") + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS") + assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS") + } +} + +func TestWireguardPortZeroExplicit(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // Create a new profile with explicit port 0 (random port) + explicitZero := 0 + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: &explicitZero, + }) + require.NoError(t, err, "should create config with explicit port 0") + + assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user") + + // Verify it persists + readConfig, err := GetConfig(configPath) + require.NoError(t, err) + assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file") +} + +func TestWireguardPortDefaultVsExplicit(t *testing.T) { + tests := []struct { + name string + wireguardPort *int + expectedPort int + description string + }{ + { + name: "no port specified uses default", + wireguardPort: nil, + expectedPort: iface.DefaultWgPort, + description: "When user doesn't specify port, default to 51820", + }, + { + name: "explicit zero for random port", + wireguardPort: func() *int { v := 0; return &v }(), + expectedPort: 0, + description: "When user explicitly sets 0, use 0 for random port", + }, + { + name: "explicit custom port", + wireguardPort: func() *int { v := 52000; return &v }(), + expectedPort: 52000, + description: "When user sets custom port, use that port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: tt.wireguardPort, + }) + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedPort, config.WgPort, tt.description) + }) + } +} + func TestUpdateOldManagementURL(t *testing.T) { tests := []struct { name string diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} diff --git a/client/server/server.go b/client/server/server.go index e6de608c5..89f50a1ef 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.CustomDNSAddress = []byte{} } + config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + + if msg.DnsRouteInterval != nil { + interval := msg.DnsRouteInterval.AsDuration() + config.DNSRouteInterval = &interval + } + config.RosenpassEnabled = msg.RosenpassEnabled config.RosenpassPermissive = msg.RosenpassPermissive config.DisableAutoConnect = msg.DisableAutoConnect diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go new file mode 100644 index 000000000..1260bcc78 --- /dev/null +++ b/client/server/setconfig_test.go @@ -0,0 +1,298 @@ +package server + +import ( + "context" + "os/user" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" +) + +// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config. +// This test uses reflection to detect when new fields are added but not handled in SetConfig. +func TestSetConfig_AllFieldsSaved(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.ConfigDirOverride = tempDir + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + + currUser, err := user.Current() + require.NoError(t, err) + + profName := "test-profile" + + ic := profilemanager.ConfigInput{ + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: "https://api.netbird.io:443", + } + _, err = profilemanager.UpdateOrCreateConfig(ic) + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + }) + require.NoError(t, err) + + ctx := context.Background() + s := New(ctx, "console", "", false, false) + + rosenpassEnabled := true + rosenpassPermissive := true + serverSSHAllowed := true + interfaceName := "utun100" + wireguardPort := int64(51820) + preSharedKey := "test-psk" + disableAutoConnect := true + networkMonitor := true + disableClientRoutes := true + disableServerRoutes := true + disableDNS := true + disableFirewall := true + blockLANAccess := true + disableNotifications := true + lazyConnectionEnabled := true + blockInbound := true + mtu := int64(1280) + + req := &proto.SetConfigRequest{ + ProfileName: profName, + Username: currUser.Username, + ManagementUrl: "https://new-api.netbird.io:443", + AdminURL: "https://new-admin.netbird.io", + RosenpassEnabled: &rosenpassEnabled, + RosenpassPermissive: &rosenpassPermissive, + ServerSSHAllowed: &serverSSHAllowed, + InterfaceName: &interfaceName, + WireguardPort: &wireguardPort, + OptionalPreSharedKey: &preSharedKey, + DisableAutoConnect: &disableAutoConnect, + NetworkMonitor: &networkMonitor, + DisableClientRoutes: &disableClientRoutes, + DisableServerRoutes: &disableServerRoutes, + DisableDns: &disableDNS, + DisableFirewall: &disableFirewall, + BlockLanAccess: &blockLANAccess, + DisableNotifications: &disableNotifications, + LazyConnectionEnabled: &lazyConnectionEnabled, + BlockInbound: &blockInbound, + NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, + CleanNATExternalIPs: false, + CustomDNSAddress: []byte("1.1.1.1:53"), + ExtraIFaceBlacklist: []string{"eth1", "eth2"}, + DnsLabels: []string{"label1", "label2"}, + CleanDNSLabels: false, + DnsRouteInterval: durationpb.New(2 * time.Minute), + Mtu: &mtu, + } + + _, err = s.SetConfig(ctx, req) + require.NoError(t, err) + + profState := profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + } + cfgPath, err := profState.FilePath() + require.NoError(t, err) + + cfg, err := profilemanager.GetConfig(cfgPath) + require.NoError(t, err) + + require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String()) + require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String()) + require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled) + require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive) + require.NotNil(t, cfg.ServerSSHAllowed) + require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed) + require.Equal(t, interfaceName, cfg.WgIface) + require.Equal(t, int(wireguardPort), cfg.WgPort) + require.Equal(t, preSharedKey, cfg.PreSharedKey) + require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect) + require.NotNil(t, cfg.NetworkMonitor) + require.Equal(t, networkMonitor, *cfg.NetworkMonitor) + require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes) + require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes) + require.Equal(t, disableDNS, cfg.DisableDNS) + require.Equal(t, disableFirewall, cfg.DisableFirewall) + require.Equal(t, blockLANAccess, cfg.BlockLANAccess) + require.NotNil(t, cfg.DisableNotifications) + require.Equal(t, disableNotifications, *cfg.DisableNotifications) + require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) + require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) + require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) + // IFaceBlackList contains defaults + extras + require.Contains(t, cfg.IFaceBlackList, "eth1") + require.Contains(t, cfg.IFaceBlackList, "eth2") + require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) + require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) + require.Equal(t, uint16(mtu), cfg.MTU) + + verifyAllFieldsCovered(t, req) +} + +// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest. +// If a new field is added to SetConfigRequest, this function will fail the test, +// forcing the developer to update both the SetConfig handler and this test. +func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { + t.Helper() + + metadataFields := map[string]bool{ + "state": true, // protobuf internal + "sizeCache": true, // protobuf internal + "unknownFields": true, // protobuf internal + "Username": true, // metadata + "ProfileName": true, // metadata + "CleanNATExternalIPs": true, // control flag for clearing + "CleanDNSLabels": true, // control flag for clearing + } + + expectedFields := map[string]bool{ + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + } + + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unexpectedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + if metadataFields[fieldName] { + continue + } + + if !expectedFields[fieldName] { + unexpectedFields = append(unexpectedFields, fieldName) + } + } + + if len(unexpectedFields) > 0 { + t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields) + } +} + +// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest. +// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq. +func TestCLIFlags_MappedToSetConfig(t *testing.T) { + // Map of CLI flag names to their corresponding SetConfigRequest field names. + // This map must be updated when adding new config-related CLI flags. + flagToField := map[string]string{ + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + } + + // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). + fieldsWithoutCLIFlags := map[string]bool{ + "DisableNotifications": true, // Only settable via UI + } + + // Get all SetConfigRequest fields to verify our map is complete. + req := &proto.SetConfigRequest{} + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unmappedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + // Skip protobuf internal fields and metadata fields. + if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" { + continue + } + if fieldName == "Username" || fieldName == "ProfileName" { + continue + } + if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" { + continue + } + + // Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag. + mappedToCLI := false + for _, mappedField := range flagToField { + if mappedField == fieldName { + mappedToCLI = true + break + } + } + + hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName] + + if !mappedToCLI && !hasNoCLIFlag { + unmappedFields = append(unmappedFields, fieldName) + } + } + + if len(unmappedFields) > 0 { + t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+ + "Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+ + "add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields) + } + + t.Log("All SetConfigRequest fields are properly documented") +} diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..5e4fcd8dc 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -205,15 +205,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "P2P" + connType := "-" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if pbPeerState.Relayed { - connType = "Relayed" + if isPeerConnected { + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } } if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 66e150b7d..865dd2731 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -31,7 +31,6 @@ import ( "fyne.io/systray" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -665,7 +664,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { } func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := open.Run(loginResp.VerificationURIComplete) + err := openURL(loginResp.VerificationURIComplete) if err != nil { log.Errorf("opening the verification uri in the browser failed: %v", err) return err @@ -1409,7 +1408,13 @@ func (s *serviceClient) updateConfig() error { } // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. -func (s *serviceClient) showLoginURL() { +// It also starts a background goroutine that periodically checks if the client is already connected +// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is +// also cancelled when the window is closed. +func (s *serviceClient) showLoginURL() context.CancelFunc { + + // create a cancellable context for the background check goroutine + ctx, cancel := context.WithCancel(s.ctx) resIcon := fyne.NewStaticResource("netbird.png", iconAbout) @@ -1418,6 +1423,8 @@ func (s *serviceClient) showLoginURL() { s.wLoginURL.Resize(fyne.NewSize(400, 200)) s.wLoginURL.SetIcon(resIcon) } + // ensure goroutine is cancelled when the window is closed + s.wLoginURL.SetOnClosed(func() { cancel() }) // add a description label label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") @@ -1498,10 +1505,46 @@ func (s *serviceClient) showLoginURL() { ) s.wLoginURL.SetContent(container.NewCenter(content)) + // start a goroutine to check connection status and close the window if connected + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return + } + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + continue + } + if status.Status == string(internal.StatusConnected) { + if s.wLoginURL != nil { + s.wLoginURL.Close() + } + return + } + } + } + }() + s.wLoginURL.Show() + + // return cancel func so callers can stop the background goroutine if desired + return cancel } func openURL(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + var err error switch runtime.GOOS { case "windows": diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 4a23a4bc8..1678c3996 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert } } -func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { - return &tls.Config{ +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config { + config := &tls.Config{ InsecureSkipVerify: true, // We'll validate manually after handshake VerifyConnection: func(cs tls.ConnectionState) error { var certChain [][]byte @@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl return nil }, } + + // CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3 + if requiresCredSSP { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS12 + } else { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS13 + } + + return config } diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 8062a05cc..16bf63bb9 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -6,11 +6,13 @@ import ( "context" "crypto/tls" "encoding/asn1" + "errors" "fmt" "io" "net" "sync" "syscall/js" + "time" log "github.com/sirupsen/logrus" ) @@ -19,18 +21,34 @@ const ( RDCleanPathVersion = 3390 RDCleanPathProxyHost = "rdcleanpath.proxy.local" RDCleanPathProxyScheme = "ws" + + rdpDialTimeout = 15 * time.Second + + GeneralErrorCode = 1 + WSAETimedOut = 10060 + WSAEConnRefused = 10061 + WSAEConnAborted = 10053 + WSAEConnReset = 10054 + WSAEGenericError = 10050 ) type RDCleanPathPDU struct { - Version int64 `asn1:"tag:0,explicit"` - Error []byte `asn1:"tag:1,explicit,optional"` - Destination string `asn1:"utf8,tag:2,explicit,optional"` - ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` - ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` - PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` - X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` - ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` - ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` + Version int64 `asn1:"tag:0,explicit"` + Error RDCleanPathErr `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathErr struct { + ErrorCode int16 `asn1:"tag:0,explicit"` + HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"` + WSALastError int16 `asn1:"tag:2,explicit,optional"` + TLSAlertCode int8 `asn1:"tag:3,explicit,optional"` } type RDCleanPathProxy struct { @@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] destination := conn.destination log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } conn.rdpConn = rdpConn @@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] _, err = rdpConn.Write(firstPacket) if err != nil { log.Errorf("Failed to write first packet: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] n, err := rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { conn.wsHandlers.Call("send", uint8Array.Get("buffer")) } } + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + p.sendToWebSocket(conn, data) +} + +func errorToWSACode(err error) int16 { + if err == nil { + return WSAEGenericError + } + var netErr *net.OpError + if errors.As(err, &netErr) && netErr.Timeout() { + return WSAETimedOut + } + if errors.Is(err, context.DeadlineExceeded) { + return WSAETimedOut + } + if errors.Is(err, context.Canceled) { + return WSAEConnAborted + } + if errors.Is(err, io.EOF) { + return WSAEConnReset + } + return WSAEGenericError +} + +func newWSAError(err error) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + WSALastError: errorToWSACode(err), + }, + } +} + +func newHTTPError(statusCode int16) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + HTTPStatusCode: statusCode, + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go index 010efa5ea..97bb46338 100644 --- a/client/wasm/internal/rdp/rdcleanpath_handlers.go +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -3,6 +3,7 @@ package rdp import ( + "context" "crypto/tls" "encoding/asn1" "io" @@ -11,11 +12,17 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP) + protocolSSL = 0x00000001 + protocolHybridEx = 0x00000008 +) + func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) if pdu.Version != RDCleanPathVersion { - p.sendRDCleanPathError(conn, "Unsupported version") + p.sendRDCleanPathError(conn, newHTTPError(400)) return } @@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl destination = pdu.Destination } - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) - p.sendRDCleanPathError(conn, "Connection failed") + p.sendRDCleanPathError(conn, newWSAError(err)) p.cleanupConnection(conn) return } @@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl p.setupTLSConnection(conn, pdu) } +// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required. +// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags. +// Returns (requiresTLS12, selectedProtocol, detectionSuccessful). +func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) { + const minResponseLength = 19 + + if len(x224Response) < minResponseLength { + return false, 0, false + } + + // Per X.224 specification: + // x224Response[0] == 0x03: Length of X.224 header (3 bytes) + // x224Response[5] == 0xD0: X.224 Data TPDU code + if x224Response[0] != 0x03 || x224Response[5] != 0xD0 { + return false, 0, false + } + + if x224Response[11] == 0x02 { + flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 | + uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24 + + hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0 + return hasNLA, flags, true + } + + return false, 0, false +} + func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { var x224Response []byte if len(pdu.X224ConnectionPDU) > 0 { @@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) if err != nil { log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean n, err := conn.rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") + p.sendRDCleanPathError(conn, newWSAError(err)) return } x224Response = response[:n] log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) } - tlsConfig := p.getTLSConfigWithValidation(conn) + requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response) + if detected { + if requiresCredSSP { + log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol) + } else { + log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol) + } + } else { + log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3") + } + + tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP) tlsConn := tls.Client(conn.rdpConn, tlsConfig) conn.tlsConn = tlsConn if err := tlsConn.Handshake(); err != nil { log.Errorf("TLS handshake failed: %v", err) - p.sendRDCleanPathError(conn, "TLS handshake failed") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean p.cleanupConnection(conn) } -func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { - if len(pdu.X224ConnectionPDU) > 0 { - log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) - _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) - if err != nil { - log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") - return - } - - response := make([]byte, 1024) - n, err := conn.rdpConn.Read(response) - if err != nil { - log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") - return - } - - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - X224ConnectionPDU: response[:n], - ServerAddr: conn.destination, - } - - p.sendRDCleanPathPDU(conn, responsePDU) - } else { - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - ServerAddr: conn.destination, - } - p.sendRDCleanPathPDU(conn, responsePDU) - } - - go p.forwardConnToWS(conn, conn.rdpConn, "TCP") - go p.forwardWSToConn(conn, conn.rdpConn, "TCP") - - <-conn.ctx.Done() - log.Debug("TCP connection context done, cleaning up") - p.cleanupConnection(conn) -} - func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { data, err := asn1.Marshal(pdu) if err != nil { @@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean p.sendToWebSocket(conn, data) } -func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { - pdu := RDCleanPathPDU{ - Version: RDCleanPathVersion, - Error: []byte(errorMsg), - } - - data, err := asn1.Marshal(pdu) - if err != nil { - log.Errorf("Failed to marshal error PDU: %v", err) - return - } - - p.sendToWebSocket(conn, data) -} - func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { msgChan := make(chan []byte) errChan := make(chan error) diff --git a/go.mod b/go.mod index a1560b409..79dd92e6b 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 13838b82d..f0065e081 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index fb01e6867..0010974c5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -47,8 +47,9 @@ services: - traefik.enable=true - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 + - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.routers.netbird-signal.service=netbird-signal - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index be9662345..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -621,7 +621,7 @@ renderCaddyfile() { # relay reverse_proxy /relay* relay:80 # Signal - reverse_proxy /ws-proxy/signal* signal:10000 + reverse_proxy /ws-proxy/signal* signal:80 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management reverse_proxy /api/* management:80 @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", diff --git a/management/internals/server/server.go b/management/internals/server/server.go index c761a98d4..1c437e361 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/http" - "net/netip" "strings" "sync" "time" @@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) } func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { - wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { switch { diff --git a/management/server/account/manager.go b/management/server/account/manager.go index a1ed9498b..fe9fb25c6 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -109,7 +109,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 66c16870b..bc352f117 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -26,9 +26,11 @@ type mockHTTPClient struct { } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - body, err := io.ReadAll(req.Body) - if err == nil { - c.reqBody = string(body) + if req.Body != nil { + body, err := io.ReadAll(req.Body) + if err == nil { + c.reqBody = string(body) + } } return &http.Response{ StatusCode: c.code, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 51f99b3b7..f06e57196 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr APIToken: config.ExtraConfig["ApiToken"], } return NewJumpCloudManager(jumpcloudConfig, appMetrics) + case "pocketid": + pocketidConfig := PocketIdClientConfig{ + APIToken: config.ExtraConfig["ApiToken"], + ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"], + } + return NewPocketIdManager(pocketidConfig, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go new file mode 100644 index 000000000..38a5cc67f --- /dev/null +++ b/management/server/idp/pocketid.go @@ -0,0 +1,384 @@ +package idp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type PocketIdManager struct { + managementEndpoint string + apiToken string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +type pocketIdCustomClaimDto struct { + Key string `json:"key"` + Value string `json:"value"` +} + +type pocketIdUserDto struct { + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + Disabled bool `json:"disabled"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + ID string `json:"id"` + IsAdmin bool `json:"isAdmin"` + LastName string `json:"lastName"` + LdapID string `json:"ldapId"` + Locale string `json:"locale"` + UserGroups []pocketIdUserGroupDto `json:"userGroups"` + Username string `json:"username"` +} + +type pocketIdUserCreateDto struct { + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + IsAdmin bool `json:"isAdmin,omitempty"` + LastName string `json:"lastName,omitempty"` + Locale string `json:"locale,omitempty"` + Username string `json:"username"` +} + +type pocketIdPaginatedUserDto struct { + Data []pocketIdUserDto `json:"data"` + Pagination pocketIdPaginationDto `json:"pagination"` +} + +type pocketIdPaginationDto struct { + CurrentPage int `json:"currentPage"` + ItemsPerPage int `json:"itemsPerPage"` + TotalItems int `json:"totalItems"` + TotalPages int `json:"totalPages"` +} + +func (p *pocketIdUserDto) userData() *UserData { + return &UserData{ + Email: p.Email, + Name: p.DisplayName, + ID: p.ID, + AppMetadata: AppMetadata{}, + } +} + +type pocketIdUserGroupDto struct { + CreatedAt string `json:"createdAt"` + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + FriendlyName string `json:"friendlyName"` + ID string `json:"id"` + LdapID string `json:"ldapId"` + Name string `json:"name"` +} + +func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + if config.ManagementEndpoint == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing") + } + + if config.APIToken == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing") + } + + credentials := &PocketIdCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &PocketIdManager{ + managementEndpoint: config.ManagementEndpoint, + apiToken: config.APIToken, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) { + var MethodsWithBody = []string{http.MethodPost, http.MethodPut} + if !slices.Contains(MethodsWithBody, method) && body != "" { + return nil, fmt.Errorf("Body provided to unsupported method: %s", method) + } + + reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource) + if query != nil { + reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode()) + } + var req *http.Request + var err error + if body != "" { + req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body)) + } else { + req, err = http.NewRequestWithContext(ctx, method, reqURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Add("X-API-KEY", p.apiToken) + + if body != "" { + req.Header.Add("content-type", "application/json") + req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength)) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// getAllUsersPaginated fetches all users from PocketID API using pagination +func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) { + var allUsers []pocketIdUserDto + currentPage := 1 + + for { + params := url.Values{} + // Copy existing search parameters + for key, values := range searchParams { + params[key] = values + } + + params.Set("pagination[limit]", "100") + params.Set("pagination[page]", fmt.Sprintf("%d", currentPage)) + + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + var profiles pocketIdPaginatedUserDto + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + allUsers = append(allUsers, profiles.Data...) + + // Check if we've reached the last page + if currentPage >= profiles.Pagination.TotalPages { + break + } + + currentPage++ + } + + return allUsers, nil +} + +func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { + body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var user pocketIdUserDto + err = p.helper.Unmarshal(body, &user) + if err != nil { + return nil, err + } + + userData := user.userData() + userData.AppMetadata = appMetadata + + return userData, nil +} + +func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAccount() + } + + users := make([]*UserData, 0) + for _, profile := range allUsers { + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountId + + users = append(users, userData) + } + return users, nil +} + +func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range allUsers { + userData := profile.userData() + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + } + + return indexedUsers, nil +} + +func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + firstLast := strings.Split(name, " ") + + createUser := pocketIdUserCreateDto{ + Disabled: false, + DisplayName: name, + Email: email, + FirstName: firstLast[0], + LastName: firstLast[1], + Username: firstLast[0] + "." + firstLast[1], + } + payload, err := p.helper.Marshal(createUser) + if err != nil { + return nil, err + } + + body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload)) + if err != nil { + return nil, err + } + var newUser pocketIdUserDto + err = p.helper.Unmarshal(body, &newUser) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountCreateUser() + } + var pending bool = true + ret := &UserData{ + Email: email, + Name: name, + ID: newUser.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pending, + WTInvitedBy: invitedByEmail, + }, + } + return ret, nil +} + +func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + params := url.Values{ + // This value a + "search": []string{email}, + } + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + var profiles struct{ data []pocketIdUserDto } + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles.data { + users = append(users, profile.userData()) + } + return users, nil +} + +func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "") + if err != nil { + return err + } + return nil +} + +func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "") + if err != nil { + return err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + +var _ Manager = (*PocketIdManager)(nil) + +type PocketIdClientConfig struct { + APIToken string + ManagementEndpoint string +} + +type PocketIdCredentials struct { + clientConfig PocketIdClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + appMetrics telemetry.AppMetrics +} + +var _ ManagerCredentials = (*PocketIdCredentials)(nil) + +func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) { + return JWTToken{}, nil +} diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go new file mode 100644 index 000000000..49075a0d3 --- /dev/null +++ b/management/server/idp/pocketid_test.go @@ -0,0 +1,138 @@ +package idp + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + + +func TestNewPocketIdManager(t *testing.T) { + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } + + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } +} + +func TestPocketID_GetUserDataByID(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) +} + +func TestPocketID_GetAccount_WithPagination(t *testing.T) { + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) +} + +func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) +} + +func TestPocketID_CreateUser(t *testing.T) { + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) +} + +func TestPocketID_InviteAndDeleteUser(t *testing.T) { + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) + + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) +} diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 21f11bfce..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,7 +146,11 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } -func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + +func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index ce632d567..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -3,18 +3,19 @@ package integrated_validator import ( "context" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) - PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d160e7269..e87043f26 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/management/server/peer.go b/management/server/peer.go index 469b41991..276a06b1a 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } var peer *nbpeer.Peer - var updateAccountPeers bool var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) - if err != nil { - return err - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers && userID != activity.SystemInitiator { + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -584,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) - if err != nil { - updateAccountPeers = true - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } @@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if updateAccountPeers { - am.BufferUpdateAccountPeers(ctx, accountID) - } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } @@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } -// IsPeerInActiveGroup checks if the given peer is part of a group that is used -// in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) - if err != nil { - return false, err - } - return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction -} - // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 42b3244ae..fd795b926 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) // + peerShouldReceiveUpdate(t, updMsg) // close(done) }() @@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index f4ad59052..342ab50c3 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -467,6 +467,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index f25603a00..6bb0a1a96 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1040,6 +1040,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1127,6 +1130,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 696c44723..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -10,7 +10,6 @@ import ( "net/http" // nolint:gosec _ "net/http/pprof" - "net/netip" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -63,10 +62,10 @@ var ( Use: "run", Short: "start NetBird Signal Server daemon", SilenceUsage: true, - PreRun: func(cmd *cobra.Command, args []string) { + PreRunE: func(cmd *cobra.Command, args []string) error { err := util.InitLog(logLevel, logFile) if err != nil { - log.Fatalf("failed initializing log %v", err) + return fmt.Errorf("failed initializing log: %w", err) } flag.Parse() @@ -87,13 +86,15 @@ var ( signalPort = 80 } } + + return nil }, RunE: func(cmd *cobra.Command, args []string) error { flag.Parse() startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -131,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -139,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -201,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -210,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { @@ -254,7 +256,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h } func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { - wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) + wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go index 977440a60..ffb622200 100644 --- a/util/wsproxy/server/proxy.go +++ b/util/wsproxy/server/proxy.go @@ -2,42 +2,41 @@ package server import ( "context" - "errors" "io" "net" "net/http" - "net/netip" "sync" "time" "github.com/coder/websocket" log "github.com/sirupsen/logrus" + "golang.org/x/net/http2" "github.com/netbirdio/netbird/util/wsproxy" ) const ( - dialTimeout = 10 * time.Second - bufferSize = 32 * 1024 + bufferSize = 32 * 1024 + ioTimeout = 5 * time.Second ) // Config contains the configuration for the WebSocket proxy. type Config struct { - LocalGRPCAddr netip.AddrPort + Handler http.Handler Path string MetricsRecorder MetricsRecorder } -// Proxy handles WebSocket to TCP proxying for gRPC connections. +// Proxy handles WebSocket to gRPC handler proxying. type Proxy struct { config Config metrics MetricsRecorder } // New creates a new WebSocket proxy instance with optional configuration -func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { +func New(handler http.Handler, opts ...Option) *Proxy { config := Config{ - LocalGRPCAddr: localGRPCAddr, + Handler: handler, Path: wsproxy.ProxyPath, MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op } @@ -63,7 +62,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { p.metrics.RecordConnection(ctx) defer p.metrics.RecordDisconnection(ctx) - log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) + log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr) acceptOptions := &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, } @@ -75,71 +74,41 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { return } defer func() { - if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { - log.Debugf("Failed to close WebSocket: %v", err) - } + _ = wsConn.Close(websocket.StatusNormalClosure, "") }() - log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) - tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) - if err != nil { - p.metrics.RecordError(ctx, "tcp_dial_failed") - log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err) - if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil { - log.Debugf("Failed to close WebSocket after connection failure: %v", err) - } - return - } + clientConn, serverConn := net.Pipe() defer func() { - if err := tcpConn.Close(); err != nil { - log.Debugf("Failed to close TCP connection: %v", err) - } + _ = clientConn.Close() + _ = serverConn.Close() }() - log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) + log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr) - p.proxyData(ctx, wsConn, tcpConn) + go func() { + (&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{ + Context: ctx, + Handler: p.config.Handler, + }) + }() + + p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr) } -func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { proxyCtx, cancel := context.WithCancel(ctx) defer cancel() var wg sync.WaitGroup wg.Add(2) - go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) - go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) + go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) + go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - select { - case <-done: - log.Tracef("Proxy data transfer completed, both goroutines terminated") - case <-proxyCtx.Done(): - log.Tracef("Proxy data transfer cancelled, forcing connection closure") - - if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil { - log.Tracef("Error closing WebSocket during cancellation: %v", err) - } - if err := tcpConn.Close(); err != nil { - log.Tracef("Error closing TCP connection during cancellation: %v", err) - } - - select { - case <-done: - log.Tracef("Goroutines terminated after forced connection closure") - case <-time.After(2 * time.Second): - log.Tracef("Goroutines did not terminate within timeout after connection closure") - } - } + wg.Wait() } -func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { defer wg.Done() defer cancel() @@ -148,80 +117,73 @@ func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync if err != nil { switch { case ctx.Err() != nil: - log.Debugf("wsToTCP goroutine terminating due to context cancellation") - case websocket.CloseStatus(err) == websocket.StatusNormalClosure: - log.Debugf("WebSocket closed normally") + log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr) + case websocket.CloseStatus(err) != -1: + log.Debugf("WebSocket from %s disconnected", clientAddr) default: p.metrics.RecordError(ctx, "websocket_read_error") - log.Errorf("WebSocket read error: %v", err) + log.Debugf("WebSocket read error from %s: %v", clientAddr, err) } return } if msgType != websocket.MessageBinary { - log.Warnf("Unexpected WebSocket message type: %v", msgType) + log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType) continue } if ctx.Err() != nil { - log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") + log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write") return } - if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Debugf("Failed to set TCP write deadline: %v", err) + if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil { + log.Debugf("Failed to set pipe write deadline: %v", err) } - n, err := tcpConn.Write(data) + n, err := pipeConn.Write(data) if err != nil { - p.metrics.RecordError(ctx, "tcp_write_error") - log.Errorf("TCP write error: %v", err) + p.metrics.RecordError(ctx, "pipe_write_error") + log.Warnf("Pipe write error for %s: %v", clientAddr, err) return } - p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) + p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n)) } } -func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { defer wg.Done() defer cancel() buf := make([]byte, bufferSize) for { - if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Debugf("Failed to set TCP read deadline: %v", err) - } - n, err := tcpConn.Read(buf) - + n, err := pipeConn.Read(buf) if err != nil { if ctx.Err() != nil { - log.Tracef("tcpToWS goroutine terminating due to context cancellation") + log.Tracef("pipeToWS goroutine terminating due to context cancellation") return } - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - continue - } - if err != io.EOF { - log.Errorf("TCP read error: %v", err) + log.Debugf("Pipe read error for %s: %v", clientAddr, err) } return } if ctx.Err() != nil { - log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") + log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write") return } - if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { - p.metrics.RecordError(ctx, "websocket_write_error") - log.Errorf("WebSocket write error: %v", err) - return - } + if n > 0 { + if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { + p.metrics.RecordError(ctx, "websocket_write_error") + log.Warnf("WebSocket write error for %s: %v", clientAddr, err) + return + } - p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) + p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n)) + } } } diff --git a/version/url_windows.go b/version/url_windows.go index f2055b109..14fdb7ae6 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -1,9 +1,13 @@ package version -import "golang.org/x/sys/windows/registry" +import ( + "golang.org/x/sys/windows/registry" + "runtime" +) const ( urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird" @@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne // DownloadUrl return with the proper download link func DownloadUrl() string { _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE) - if err == nil { - return urlWinExe - } else { + if err != nil { return downloadURL } + + url := urlWinExe + if runtime.GOARCH == "arm64" { + url = urlWinExeArm + } + + return url }