diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..d53c5f06b 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/cmd/login.go b/client/cmd/login.go index 3ac211805..40b55f858 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/user" "runtime" "strings" @@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := open.Run(verificationURIComplete); err != nil { + if err := openBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } +// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func openBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} + // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..d78372c9e 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return "" } - // Include action in the ipset name to prevent squashing rules with different actions actionSuffix := "" if action == firewall.ActionDrop { actionSuffix = "-drop" diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5ca950297..965decc73 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -29,11 +29,6 @@ type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } -type protoMatch struct { - ips map[string]int - policyID []byte -} - // DefaultManager uses firewall manager to handle type DefaultManager struct { firewall firewall.Manager @@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout } func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { - rules, squashedProtocols := d.squashAcceptRules(networkMap) + rules := networkMap.FirewallRules enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig.SshEnabled - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - enableSSH = enableSSH && !ok - } - if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { - enableSSH = enableSSH && !ok - } - // if TCP protocol rules not squashed and SSH enabled - // we add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP 22 port). + // If SSH enabled, add default firewall rule which accepts connection to any peer + // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", @@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID( return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } -// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type -// to all peers in the network map to one rule which just accepts that type of the traffic. -// -// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, -// but other has port definitions or has drop policy. -func (d *DefaultManager) squashAcceptRules( - networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { - totalIPs := 0 - for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { - for range p.AllowedIps { - totalIPs++ - } - } - - in := map[mgmProto.RuleProtocol]*protoMatch{} - out := map[mgmProto.RuleProtocol]*protoMatch{} - - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} - - // this function we use to do calculation, can we squash the rules by protocol or not. - // We summ amount of Peers IP for given protocol we found in original rules list. - // But we zeroed the IP's for protocol if: - // 1. Any of the rule has DROP action type. - // 2. Any of rule contains Port. - // - // We zeroed this to notify squash function that this protocol can't be squashed. - addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || - r.Port != "" || !portInfoEmpty(r.PortInfo) - - if hasPortRestrictions { - // Don't squash rules with port restrictions - protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} - return - } - - if _, ok := protocols[r.Protocol]; !ok { - protocols[r.Protocol] = &protoMatch{ - ips: map[string]int{}, - // store the first encountered PolicyID for this protocol - policyID: r.PolicyID, - } - } - - // special case, when we receive this all network IP address - // it means that rules for that protocol was already optimized on the - // management side - if r.PeerIP == "0.0.0.0" { - squashedRules = append(squashedRules, r) - squashedProtocols[r.Protocol] = struct{}{} - return - } - - ipset := protocols[r.Protocol].ips - - if _, ok := ipset[r.PeerIP]; ok { - return - } - ipset[r.PeerIP] = i - } - - for i, r := range networkMap.FirewallRules { - // calculate squash for different directions - if r.Direction == mgmProto.RuleDirection_IN { - addRuleToCalculationMap(i, r, in) - } else { - addRuleToCalculationMap(i, r, out) - } - } - - // order of squashing by protocol is important - // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.RuleProtocol{ - mgmProto.RuleProtocol_ALL, - mgmProto.RuleProtocol_ICMP, - mgmProto.RuleProtocol_TCP, - mgmProto.RuleProtocol_UDP, - } - - squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { - for _, protocol := range protocolOrders { - match, ok := matches[protocol] - if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { - // don't squash if : - // 1. Rules not cover all peers in the network - // 2. Rules cover only one peer in the network. - continue - } - - // add special rule 0.0.0.0 which allows all IP's in our firewall implementations - squashedRules = append(squashedRules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: direction, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: protocol, - PolicyID: match.policyID, - }) - squashedProtocols[protocol] = struct{}{} - - if protocol == mgmProto.RuleProtocol_ALL { - // if we have ALL traffic type squashed rule - // it allows all other type of traffic, so we can stop processing - break - } - } - } - - squash(in, mgmProto.RuleDirection_IN) - squash(out, mgmProto.RuleDirection_OUT) - - // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - return squashedRules, squashedProtocols - } - - if len(squashedRules) == 0 { - return networkMap.FirewallRules, squashedProtocols - } - - var rules []*mgmProto.FirewallRule - // filter out rules which was squashed from final list - // if we also have other not squashed rules. - for i, r := range networkMap.FirewallRules { - if _, ok := squashedProtocols[r.Protocol]; ok { - if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } - } - rules = append(rules, r) - } - - return append(rules, squashedRules...), squashedProtocols -} - // getRuleGroupingSelector takes all rule properties except IP address to build selector func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 664476ef4..daf4979ce 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) { }) } -func TestDefaultManagerSquashRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, 2, len(rules)) - - r := rules[0] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) - - r = rules[1] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) -} - -func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, len(networkMap.FirewallRules), len(rules)) -} - -func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { - tests := []struct { - name string - rules []*mgmProto.FirewallRule - expectedCount int - description string - }{ - { - name: "should not squash rules with port ranges", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with port ranges should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with specific ports", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with specific ports should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with legacy port field", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - }, - expectedCount: 4, - description: "Rules with legacy port field should not be squashed", - }, - { - name: "should not squash rules with DROP action", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "Rules with DROP action should not be squashed", - }, - { - name: "should squash rules without port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 1, - description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule", - }, - { - name: "mixed rules should not squash protocol with port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "TCP should not be squashed because one rule has port restrictions", - }, - { - name: "should squash UDP but not TCP when TCP has port restrictions", - rules: []*mgmProto.FirewallRule{ - // TCP rules with port restrictions - should NOT be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - // UDP rules without port restrictions - SHOULD be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) - description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: tt.rules, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - - assert.Equal(t, tt.expectedCount, len(rules), tt.description) - - // For squashed rules, verify we get the expected 0.0.0.0 rule - if tt.expectedCount == 1 { - assert.Equal(t, "0.0.0.0", rules[0].PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) - } - }) - } -} - func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index ec920c5f3..442f54e71 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) @@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..d590dba0d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -106,7 +106,7 @@ type DefaultManager struct { func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 index 000000000..f1c45d6cf --- /dev/null +++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. +func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..d9b109beb 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + + if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..5e4fcd8dc 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -205,15 +205,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "P2P" + connType := "-" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if pbPeerState.Relayed { - connType = "Relayed" + if isPeerConnected { + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } } if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7c2000a9d..0043f228e 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -31,7 +31,6 @@ import ( "fyne.io/systray" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { } func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := open.Run(loginResp.VerificationURIComplete) + err := openURL(loginResp.VerificationURIComplete) if err != nil { log.Errorf("opening the verification uri in the browser failed: %v", err) return err @@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } func openURL(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + var err error switch runtime.GOOS { case "windows": diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } diff --git a/go.mod b/go.mod index cb88f92d3..2f76c0766 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 6b0b298a7..ce68ed99e 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index daec4ef6f..209a20065 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 0c493f07d..db377865a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -109,7 +109,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 741f03f18..bdf56db6e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 251c04273..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index be05c2527..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -15,6 +15,7 @@ type IntegratedValidator interface { PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2d691ba03..8baffa58b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -190,17 +190,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go index 6f66b787f..5fb757b46 100644 --- a/management/server/types/networkmap.go +++ b/management/server/types/networkmap.go @@ -136,9 +136,8 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -148,14 +147,6 @@ func (a *Account) GetPeerNetworkMap( dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) } - // nm := GetNetworkMap() - // nm.Peers = peersToConnectIncludingRouters - // nm.Network = a.Network.Copy() - // nm.Routes = slices.Concat(networkResourcesRoutes, routesUpdate) - // nm.DNSConfig = dnsUpdate - // nm.OfflinePeers = expiredPeers - // nm.FirewallRules = firewallRules - // nm.RoutesFirewallRules = slices.Concat(networkResourcesFirewallRules, routesFirewallRules) nm := &NetworkMap{ Peers: peersToConnectIncludingRouters, Network: a.Network.Copy(), @@ -929,7 +920,7 @@ func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -940,6 +931,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..4a5454002 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -463,6 +463,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..9611d26d6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1037,6 +1037,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1124,6 +1127,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {