diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index e1e1ff236..b584f0ff6 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -49,7 +49,7 @@ jobs:
         run: git --no-pager diff --exit-code

       - name: Test
-        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
+        run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...

   test_client_on_docker:
     runs-on: ubuntu-20.04
@@ -79,9 +79,6 @@ jobs:
       - name: check git status
         run: git --no-pager diff --exit-code

-      - name: Generate Iface Test bin
-        run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
-
       - name: Generate Shared Sock Test bin
         run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -98,7 +95,7 @@ jobs:
         run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal

       - name: Generate Peer Test bin
-        run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
+        run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/

       - run: chmod +x *testing.bin
@@ -106,7 +103,7 @@ jobs:
         run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1

       - name: Run Iface tests in docker
-        run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
+        run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...

       - name: Run RouteManager tests in docker
         run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index 2d743f790..dacb1922b 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -19,7 +19,7 @@ jobs:
       - name: codespell
         uses: codespell-project/actions-codespell@v2
         with:
-          ignore_words_list: erro,clienta,hastable,iif
+          ignore_words_list: erro,clienta,hastable,iif,groupd
          skip: go.mod,go.sum
          only_warn: 1
  golangci:
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 1b85ec7ef..14e383a27 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -9,7 +9,7 @@ on:
   pull_request:

 env:
-  SIGN_PIPE_VER: "v0.0.15"
+  SIGN_PIPE_VER: "v0.0.16"
   GORELEASER_VER: "v2.3.2"
   PRODUCT_NAME: "NetBird"
   COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
diff --git a/client/cmd/status.go b/client/cmd/status.go
index ed3daa2b5..6db52a677 100644
--- a/client/cmd/status.go
+++ b/client/cmd/status.go
@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
 func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
 	statusEval := false
 	ipEval := false
-	nameEval := false
+	nameEval := true

 	if statusFilter != "" {
 		lowerStatusFilter := strings.ToLower(statusFilter)
@@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {

 	if len(prefixNamesFilter) > 0 {
 		for prefixNameFilter := range prefixNamesFilterMap {
-			if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
-				nameEval = true
+			if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
+				nameEval = false
 				break
 			}
 		}
+	} else {
+		nameEval = false
 	}

 	return statusEval || ipEval || nameEval
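The `status.go` change above fixes the prefix-name filter: `nameEval` now defaults to skip whenever prefix filters are set and flips to keep on the first matching prefix, whereas the old logic showed a peer as soon as any single filter failed to match. A test-style sketch of the corrected semantics (FQDNs and prefixes are made up for illustration):

```go
package main

import (
	"fmt"
	"strings"
)

// Sketch only: mirrors the corrected name-filter logic in skipDetailByFilters.
func skipByNamePrefix(fqdn string, prefixes []string) bool {
	if len(prefixes) == 0 {
		return false // no filter configured: never skip by name
	}
	for _, p := range prefixes {
		if strings.HasPrefix(fqdn, p) {
			return false // matched one prefix: keep the peer
		}
	}
	return true // filters set, none matched: skip the peer
}

func main() {
	prefixes := []string{"office-", "lab-"}
	fmt.Println(skipByNamePrefix("office-1.netbird.cloud", prefixes)) // false (shown)
	fmt.Println(skipByNamePrefix("home-1.netbird.cloud", prefixes))   // true (skipped)
}
```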
diff --git a/client/firewall/create.go b/client/firewall/create.go
index 86ce94cea..9466f4b4d 100644
--- a/client/firewall/create.go
+++ b/client/firewall/create.go
@@ -3,7 +3,6 @@ package firewall
 import (
-	"context"
 	"fmt"
 	"runtime"

@@ -11,10 +10,11 @@ import (
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/firewall/uspfilter"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 // NewFirewall creates a firewall manager instance
-func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
+func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
 	if !iface.IsUserspaceBind() {
 		return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
 	}
diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go
index 92deb63dc..076d08ec2 100644
--- a/client/firewall/create_linux.go
+++ b/client/firewall/create_linux.go
@@ -3,7 +3,7 @@ package firewall
 import (
-	"context"
+	"errors"
 	"fmt"
 	"os"

@@ -15,6 +15,7 @@ import (
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
 	"github.com/netbirdio/netbird/client/firewall/uspfilter"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 const (
@@ -32,54 +33,65 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
 // FWType is the type for the firewall type
 type FWType int

-func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
+func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
 	// on the linux system we try to user nftables or iptables
 	// in any case, because we need to allow netbird interface traffic
 	// so we use AllowNetbird traffic from these firewall managers
 	// for the userspace packet filtering firewall
-	var fm firewall.Manager
-	var errFw error
+	fm, err := createNativeFirewall(iface, stateManager)
+	if !iface.IsUserspaceBind() {
+		return fm, err
+	}
+
+	if err != nil {
+		log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
+	}
+	return createUserspaceFirewall(iface, fm)
+}
+
+func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
+	fm, err := createFW(iface)
+	if err != nil {
+		return nil, fmt.Errorf("create firewall: %s", err)
+	}
+
+	if err = fm.Init(stateManager); err != nil {
+		return nil, fmt.Errorf("init firewall: %s", err)
+	}
+
+	return fm, nil
+}
+
+func createFW(iface IFaceMapper) (firewall.Manager, error) {
 	switch check() {
 	case IPTABLES:
 		log.Info("creating an iptables firewall manager")
-		fm, errFw = nbiptables.Create(context, iface)
-		if errFw != nil {
-			log.Errorf("failed to create iptables manager: %s", errFw)
-		}
+		return nbiptables.Create(iface)
 	case NFTABLES:
 		log.Info("creating an nftables firewall manager")
-		fm, errFw = nbnftables.Create(context, iface)
-		if errFw != nil {
-			log.Errorf("failed to create nftables manager: %s", errFw)
-		}
+		return nbnftables.Create(iface)
 	default:
-		errFw = fmt.Errorf("no firewall manager found")
 		log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
+		return nil, errors.New("no firewall manager found")
+	}
+}
+
+func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
+	var errUsp error
+	if fm != nil {
+		fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
+	} else {
+		fm, errUsp = uspfilter.Create(iface)
 	}

-	if iface.IsUserspaceBind() {
-		var errUsp error
-		if errFw == nil {
-			fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
-		} else {
-			fm, errUsp = uspfilter.Create(iface)
-		}
-		if errUsp != nil {
-			log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
-			return nil, errUsp
-		}
-
-		if err := fm.AllowNetbird(); err != nil {
-			log.Errorf("failed to allow netbird interface traffic: %v", err)
-		}
-		return fm, nil
+	if errUsp != nil {
+		return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
 	}

-	if errFw != nil {
-		return nil, errFw
+	if err := fm.AllowNetbird(); err != nil {
+		log.Errorf("failed to allow netbird interface traffic: %v", err)
 	}
+
 	return fm, nil
 }
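Construction and initialization are now separate steps across all firewall managers: `Create`/`NewFirewall` builds the manager without a context, `Init(stateManager)` wires up persistence (run inside `createNativeFirewall` on Linux), and teardown goes through `Reset(stateManager)`. A sketch of the resulting call sequence from a hypothetical caller — `setupFirewall` and its arguments are illustrative, only the API calls come from this diff:

```go
package example

import (
	"fmt"

	nbfw "github.com/netbirdio/netbird/client/firewall"
	"github.com/netbirdio/netbird/client/internal/statemanager"
)

// Illustrative wiring, not the engine's actual code.
func setupFirewall(iface nbfw.IFaceMapper, stateMgr *statemanager.Manager) error {
	fm, err := nbfw.NewFirewall(iface, stateMgr) // on Linux, Init(stateMgr) runs inside
	if err != nil {
		return fmt.Errorf("create firewall: %w", err)
	}

	// ... install ACL and routing rules through fm ...

	// On shutdown: remove the rules and drop the persisted shutdown state.
	return fm.Reset(stateMgr)
}
```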
diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go
index c271e592d..5cd69245b 100644
--- a/client/firewall/iptables/acl_linux.go
+++ b/client/firewall/iptables/acl_linux.go
@@ -11,6 +11,7 @@ import (
 	log "github.com/sirupsen/logrus"

 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 	nbnet "github.com/netbirdio/netbird/util/net"
 )

@@ -22,6 +23,8 @@ const (
 	chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
 )

+type aclEntries map[string][][]string
+
 type entry struct {
 	spec     []string
 	position int
@@ -32,9 +35,11 @@ type aclManager struct {
 	wgIface            iFaceMapper
 	routingFwChainName string

-	entries         map[string][][]string
+	entries         aclEntries
 	optionalEntries map[string][]entry
 	ipsetStore      *ipsetStore
+
+	stateManager *statemanager.Manager
 }

 func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
 		ipsetStore:      newIpsetStore(),
 	}

-	err := ipset.Init()
-	if err != nil {
-		return nil, fmt.Errorf("failed to init ipset: %w", err)
+	if err := ipset.Init(); err != nil {
+		return nil, fmt.Errorf("init ipset: %w", err)
 	}

+	return m, nil
+}
+
+func (m *aclManager) init(stateManager *statemanager.Manager) error {
+	m.stateManager = stateManager
+
 	m.seedInitialEntries()
 	m.seedInitialOptionalEntries()

-	err = m.cleanChains()
-	if err != nil {
-		return nil, err
+	if err := m.cleanChains(); err != nil {
+		return fmt.Errorf("clean chains: %w", err)
 	}

-	err = m.createDefaultChains()
-	if err != nil {
-		return nil, err
+	if err := m.createDefaultChains(); err != nil {
+		return fmt.Errorf("create default chains: %w", err)
 	}
-	return m, nil
+
+	m.updateState()
+
+	return nil
 }

 func (m *aclManager) AddPeerFiltering(
@@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering(
 		chain: chain,
 	}

+	m.updateState()
+
 	return []firewall.Rule{rule}, nil
 }

@@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
 		}
 	}

-	err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
-	if err != nil {
-		log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
+	if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
+		return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
 	}
-	return err
+
+	m.updateState()
+
+	return nil
 }

 func (m *aclManager) Reset() error {
-	return m.cleanChains()
+	if err := m.cleanChains(); err != nil {
+		return fmt.Errorf("clean chains: %w", err)
+	}
+
+	m.updateState()
+
+	return nil
 }

 // todo write less destructive cleanup mechanism
@@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) {
 	m.entries[chainName] = append(m.entries[chainName], spec)
 }

+func (m *aclManager) updateState() {
+	if m.stateManager == nil {
+		return
+	}
+
+	var currentState *ShutdownState
+	if existing := m.stateManager.GetState(currentState); existing != nil {
+		if existingState, ok := existing.(*ShutdownState); ok {
+			currentState = existingState
+		}
+	}
+	if currentState == nil {
+		currentState = &ShutdownState{}
+	}
+
+	currentState.Lock()
+	defer currentState.Unlock()
+
+	currentState.ACLEntries = m.entries
+	currentState.ACLIPsetStore = m.ipsetStore
+
+	if err := m.stateManager.UpdateState(currentState); err != nil {
+		log.Errorf("failed to update state: %v", err)
+	}
+}
+
 // filterRuleSpecs returns the specs of a filtering rule
 func filterRuleSpecs(
 	ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
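`updateState` is the pattern both halves of the iptables manager use to mirror their in-memory bookkeeping into one shared `ShutdownState`. An annotated restatement of its skeleton; the comments are mine, and the reading of `GetState` as keyed-by-concrete-type is inferred from this usage rather than stated in the diff:

```go
// Skeleton of the snapshot pattern shared by aclManager and router.
func (m *aclManager) snapshotSketch() {
	if m.stateManager == nil {
		return // tests run without persistence (Init(nil), Reset(nil))
	}

	// The typed nil pointer acts as a lookup token for the stored state.
	var state *ShutdownState
	if existing := m.stateManager.GetState(state); existing != nil {
		if s, ok := existing.(*ShutdownState); ok {
			state = s // reuse the shared state so the router's fields survive
		}
	}
	if state == nil {
		state = &ShutdownState{}
	}

	state.Lock() // ShutdownState embeds sync.Mutex; router writes may interleave
	defer state.Unlock()

	state.ACLEntries = m.entries // each component touches only its own fields
	state.ACLIPsetStore = m.ipsetStore

	if err := m.stateManager.UpdateState(state); err != nil {
		log.Errorf("failed to update state: %v", err)
	}
}
```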
diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go
index 94bd2fccf..a59bd2c60 100644
--- a/client/firewall/iptables/manager_linux.go
+++ b/client/firewall/iptables/manager_linux.go
@@ -8,10 +8,13 @@ import (
 	"sync"

 	"github.com/coreos/go-iptables/iptables"
+	"github.com/hashicorp/go-multierror"
 	log "github.com/sirupsen/logrus"

+	nberrors "github.com/netbirdio/netbird/client/errors"
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/iface"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 // Manager of iptables firewall
@@ -33,10 +36,10 @@ type iFaceMapper interface {
 }

 // Create iptables firewall manager
-func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
+func Create(wgIface iFaceMapper) (*Manager, error) {
 	iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
 	if err != nil {
-		return nil, fmt.Errorf("iptables is not installed in the system or not supported")
+		return nil, fmt.Errorf("init iptables: %w", err)
 	}

 	m := &Manager{
@@ -44,20 +47,49 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
 		ipv4Client: iptablesClient,
 	}

-	m.router, err = newRouter(context, iptablesClient, wgIface)
+	m.router, err = newRouter(iptablesClient, wgIface)
 	if err != nil {
-		log.Debugf("failed to initialize route related chains: %s", err)
-		return nil, err
+		return nil, fmt.Errorf("create router: %w", err)
 	}
+
 	m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
 	if err != nil {
-		log.Debugf("failed to initialize ACL manager: %s", err)
-		return nil, err
+		return nil, fmt.Errorf("create acl manager: %w", err)
 	}

 	return m, nil
 }

+func (m *Manager) Init(stateManager *statemanager.Manager) error {
+	state := &ShutdownState{
+		InterfaceState: &InterfaceState{
+			NameStr:       m.wgIface.Name(),
+			WGAddress:     m.wgIface.Address(),
+			UserspaceBind: m.wgIface.IsUserspaceBind(),
+		},
+	}
+	stateManager.RegisterState(state)
+	if err := stateManager.UpdateState(state); err != nil {
+		log.Errorf("failed to update state: %v", err)
+	}
+
+	if err := m.router.init(stateManager); err != nil {
+		return fmt.Errorf("router init: %w", err)
+	}
+
+	if err := m.aclMgr.init(stateManager); err != nil {
+		// TODO: cleanup router
+		return fmt.Errorf("acl manager init: %w", err)
+	}
+
+	// persist early to ensure cleanup of chains
+	if err := stateManager.PersistState(context.Background()); err != nil {
+		log.Errorf("failed to persist state: %v", err)
+	}
+
+	return nil
+}
+
 // AddPeerFiltering adds a rule to the firewall
 //
 // Comment will be ignored because some system this feature is not supported
@@ -133,20 +165,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
 }

 // Reset firewall to the default state
-func (m *Manager) Reset() error {
+func (m *Manager) Reset(stateManager *statemanager.Manager) error {
 	m.mutex.Lock()
 	defer m.mutex.Unlock()

-	errAcl := m.aclMgr.Reset()
-	if errAcl != nil {
-		log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
+	var merr *multierror.Error
+
+	if err := m.aclMgr.Reset(); err != nil {
+		merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
 	}
-	errMgr := m.router.Reset()
-	if errMgr != nil {
-		log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
-		return errMgr
+
+	if err := m.router.Reset(); err != nil {
+		merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
 	}
-	return errAcl
+
+	// attempt to delete state only if all other operations succeeded
+	if merr == nil {
+		if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
+			merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
+		}
+	}
+
+	return nberrors.FormatErrorOrNil(merr)
 }

 // AllowNetbird allows netbird interface traffic
diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go
index 498d8f58b..ebdb83137 100644
--- a/client/firewall/iptables/manager_linux_test.go
+++ b/client/firewall/iptables/manager_linux_test.go
@@ -1,7 +1,6 @@
 package iptables

 import (
-	"context"
 	"fmt"
 	"net"
 	"testing"
@@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) {
 	require.NoError(t, err)

 	// just check on the local interface
-	manager, err := Create(context.Background(), ifaceMock)
+	manager, err := Create(ifaceMock)
 	require.NoError(t, err)
+	require.NoError(t, manager.Init(nil))

 	time.Sleep(time.Second)

 	defer func() {
-		err := manager.Reset()
+		err := manager.Reset(nil)
 		require.NoError(t, err, "clear the manager state")

 		time.Sleep(time.Second)
@@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
 		_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
 		require.NoError(t, err, "failed to add rule")

-		err = manager.Reset()
+		err = manager.Reset(nil)
 		require.NoError(t, err, "failed to reset")

 		ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
 	}

 	// just check on the local interface
-	manager, err := Create(context.Background(), mock)
+	manager, err := Create(mock)
 	require.NoError(t, err)
+	require.NoError(t, manager.Init(nil))

 	time.Sleep(time.Second)

 	defer func() {
-		err := manager.Reset()
+		err := manager.Reset(nil)
 		require.NoError(t, err, "clear the manager state")

 		time.Sleep(time.Second)
@@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
 	})

 	t.Run("reset check", func(t *testing.T) {
-		err = manager.Reset()
+		err = manager.Reset(nil)
 		require.NoError(t, err, "failed to reset")
 	})
 }
@@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
 	for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
 		t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
 			// just check on the local interface
-			manager, err := Create(context.Background(), mock)
+			manager, err := Create(mock)
 			require.NoError(t, err)
+			require.NoError(t, manager.Init(nil))

 			time.Sleep(time.Second)

 			defer func() {
-				err := manager.Reset()
+				err := manager.Reset(nil)
 				require.NoError(t, err, "clear the manager state")

 				time.Sleep(time.Second)
diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go
index 129323928..9b75640b4 100644
--- a/client/firewall/iptables/router_linux.go
+++ b/client/firewall/iptables/router_linux.go
@@ -3,7 +3,6 @@ package iptables
 import (
-	"context"
 	"fmt"
 	"net/netip"
 	"strconv"
@@ -18,6 +17,7 @@ import (
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/internal/acl/id"
 	"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 const (
@@ -48,28 +48,31 @@ type routeFilteringRuleParams struct {
 	SetName     string
 }

+type routeRules map[string][]string
+
+type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
+
 type router struct {
-	ctx              context.Context
-	stop             context.CancelFunc
 	iptablesClient   *iptables.IPTables
-	rules            map[string][]string
-	ipsetCounter     *refcounter.Counter[string, []netip.Prefix, struct{}]
+	rules            routeRules
+	ipsetCounter     *ipsetCounter
 	wgIface          iFaceMapper
 	legacyManagement bool
+
+	stateManager *statemanager.Manager
 }

-func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
-	ctx, cancel := context.WithCancel(parentCtx)
+func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
 	r := &router{
-		ctx:            ctx,
-		stop:           cancel,
 		iptablesClient: iptablesClient,
 		rules:          make(map[string][]string),
 		wgIface:        wgIface,
 	}

 	r.ipsetCounter = refcounter.New(
-		r.createIpSet,
+		func(name string, sources []netip.Prefix) (struct{}, error) {
+			return struct{}{}, r.createIpSet(name, sources)
+		},
 		func(name string, _ struct{}) error {
 			return r.deleteIpSet(name)
 		},
@@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
 		return nil, fmt.Errorf("init ipset: %w", err)
 	}

-	err := r.cleanUpDefaultForwardRules()
-	if err != nil {
-		log.Errorf("cleanup routing rules: %s", err)
-		return nil, err
+	return r, nil
+}
+
+func (r *router) init(stateManager *statemanager.Manager) error {
+	r.stateManager = stateManager
+
+	if err := r.cleanUpDefaultForwardRules(); err != nil {
+		log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
 	}
-	err = r.createContainers()
-	if err != nil {
-		log.Errorf("create containers for route: %s", err)
+
+	if err := r.createContainers(); err != nil {
+		return fmt.Errorf("create containers: %w", err)
 	}
-	return r, err
+
+	r.updateState()
+
+	return nil
 }

 func (r *router) AddRouteFiltering(
@@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering(

 	r.rules[string(ruleKey)] = rule

+	r.updateState()
+
 	return ruleKey, nil
 }

@@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
 		log.Debugf("route rule %s not found", ruleKey)
 	}

+	r.updateState()
+
 	return nil
 }

@@ -164,18 +178,18 @@ func (r *router) findSetNameInRule(rule []string) string {
 	return ""
 }

-func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
+func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
 	if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
-		return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
+		return fmt.Errorf("create set %s: %w", setName, err)
 	}

 	for _, prefix := range sources {
 		if err := ipset.AddPrefix(setName, prefix); err != nil {
-			return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
+			return fmt.Errorf("add element to set %s: %w", setName, err)
 		}
 	}

-	return struct{}{}, nil
+	return nil
 }

 func (r *router) deleteIpSet(setName string) error {
@@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
 		return fmt.Errorf("add inverse nat rule: %w", err)
 	}

+	r.updateState()
+
 	return nil
 }

@@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
 		return fmt.Errorf("remove legacy routing rule: %w", err)
 	}

+	r.updateState()
+
 	return nil
 }

@@ -278,8 +296,13 @@ func (r *router) RemoveAllLegacyRouteRules() error {
 		}
 		if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
 			merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
+		} else {
+			delete(r.rules, k)
 		}
 	}
+
+	r.updateState()
+
 	return nberrors.FormatErrorOrNil(merr)
 }

@@ -294,6 +317,8 @@ func (r *router) Reset() error {
 		merr = multierror.Append(merr, err)
 	}

+	r.updateState()
+
 	return nberrors.FormatErrorOrNil(merr)
 }

@@ -431,6 +456,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
 	return nil
 }

+func (r *router) updateState() {
+	if r.stateManager == nil {
+		return
+	}
+
+	var currentState *ShutdownState
+	if existing := r.stateManager.GetState(currentState); existing != nil {
+		if existingState, ok := existing.(*ShutdownState); ok {
+			currentState = existingState
+		}
+	}
+	if currentState == nil {
+		currentState = &ShutdownState{}
+	}
+
+	currentState.Lock()
+	defer currentState.Unlock()
+
+	currentState.RouteRules = r.rules
+	currentState.RouteIPsetCounter = r.ipsetCounter
+
+	if err := r.stateManager.UpdateState(currentState); err != nil {
+		log.Errorf("failed to update state: %v", err)
+	}
+}
+
 func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
 	intdir := "-i"
 	lointdir := "-o"
diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go
index 6cede09e2..2d821a9db 100644
--- a/client/firewall/iptables/router_linux_test.go
+++ b/client/firewall/iptables/router_linux_test.go
@@ -3,7 +3,6 @@ package iptables
 import (
-	"context"
 	"net/netip"
 	"os/exec"
 	"testing"
@@ -30,8 +29,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
 	iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
 	require.NoError(t, err, "failed to init iptables client")

-	manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
+	manager, err := newRouter(iptablesClient, ifaceMock)
 	require.NoError(t, err, "should return a valid iptables manager")
+	require.NoError(t, manager.init(nil))

 	defer func() {
 		_ = manager.Reset()
@@ -74,8 +74,9 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
 	iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
 	require.NoError(t, err, "failed to init iptables client")

-	manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
+	manager, err := newRouter(iptablesClient, ifaceMock)
 	require.NoError(t, err, "shouldn't return error")
+	require.NoError(t, manager.init(nil))

 	defer func() {
 		err := manager.Reset()
@@ -132,8 +133,9 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
 		t.Run(testCase.Name, func(t *testing.T) {
 			iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)

-			manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
+			manager, err := newRouter(iptablesClient, ifaceMock)
 			require.NoError(t, err, "shouldn't return error")
+			require.NoError(t, manager.init(nil))

 			defer func() {
 				_ = manager.Reset()
 			}()
@@ -183,8 +185,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
 	iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
 	require.NoError(t, err, "Failed to create iptables client")

-	r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
+	r, err := newRouter(iptablesClient, ifaceMock)
 	require.NoError(t, err, "Failed to create router manager")
+	require.NoError(t, r.init(nil))

 	defer func() {
 		err := r.Reset()
diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go
index a9470c9ac..bfd08bee2 100644
--- a/client/firewall/iptables/rulestore_linux.go
+++ b/client/firewall/iptables/rulestore_linux.go
@@ -1,14 +1,16 @@
 package iptables

+import "encoding/json"
+
 type ipList struct {
 	ips map[string]struct{}
 }

-func newIpList(ip string) ipList {
+func newIpList(ip string) *ipList {
 	ips := make(map[string]struct{})
 	ips[ip] = struct{}{}

-	return ipList{
+	return &ipList{
 		ips: ips,
 	}
 }
@@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
 	s.ips[ip] = struct{}{}
 }

+// MarshalJSON implements json.Marshaler
+func (s *ipList) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		IPs map[string]struct{} `json:"ips"`
+	}{
+		IPs: s.ips,
+	})
+}
+
+// UnmarshalJSON implements json.Unmarshaler
+func (s *ipList) UnmarshalJSON(data []byte) error {
+	temp := struct {
+		IPs map[string]struct{} `json:"ips"`
+	}{}
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+	s.ips = temp.IPs
+	return nil
+}
+
 type ipsetStore struct {
-	ipsets map[string]ipList // ipsetName -> ruleset
+	ipsets map[string]*ipList
 }

 func newIpsetStore() *ipsetStore {
 	return &ipsetStore{
-		ipsets: make(map[string]ipList),
+		ipsets: make(map[string]*ipList),
 	}
 }

-func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
+func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
 	r, ok := s.ipsets[ipsetName]
 	return r, ok
 }

-func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
+func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
 	s.ipsets[ipsetName] = list
 }

 func (s *ipsetStore) deleteIpset(ipsetName string) {
-	s.ipsets[ipsetName] = ipList{}
 	delete(s.ipsets, ipsetName)
 }

@@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
 	}
 	return names
 }
+
+// MarshalJSON implements json.Marshaler
+func (s *ipsetStore) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		IPSets map[string]*ipList `json:"ipsets"`
+	}{
+		IPSets: s.ipsets,
+	})
+}
+
+// UnmarshalJSON implements json.Unmarshaler
+func (s *ipsetStore) UnmarshalJSON(data []byte) error {
+	temp := struct {
+		IPSets map[string]*ipList `json:"ipsets"`
+	}{}
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+	s.ipsets = temp.IPSets
+	return nil
+}
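The custom (un)marshalers above exist because `ipList` and `ipsetStore` keep their data in unexported fields, which `encoding/json` would otherwise silently skip, leaving an empty shutdown state on disk. A hypothetical package-internal test sketching the round trip:

```go
package iptables

import (
	"encoding/json"
	"reflect"
	"testing"
)

// Hypothetical test: verifies that state written at shutdown can be read back.
func TestIpsetStoreRoundTrip(t *testing.T) {
	store := newIpsetStore()
	list := newIpList("10.0.0.1")
	list.addIP("10.0.0.2")
	store.addIpList("nb-set", list)

	data, err := json.Marshal(store) // uses ipsetStore.MarshalJSON
	if err != nil {
		t.Fatal(err)
	}

	restored := newIpsetStore()
	if err := json.Unmarshal(data, restored); err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(store, restored) {
		t.Fatalf("round trip mismatch: %s", data)
	}
}
```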
diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go
new file mode 100644
index 000000000..44b8340ba
--- /dev/null
+++ b/client/firewall/iptables/state_linux.go
@@ -0,0 +1,70 @@
+package iptables
+
+import (
+	"fmt"
+	"sync"
+
+	"github.com/netbirdio/netbird/client/iface"
+	"github.com/netbirdio/netbird/client/iface/device"
+)
+
+type InterfaceState struct {
+	NameStr       string          `json:"name"`
+	WGAddress     iface.WGAddress `json:"wg_address"`
+	UserspaceBind bool            `json:"userspace_bind"`
+}
+
+func (i *InterfaceState) Name() string {
+	return i.NameStr
+}
+
+func (i *InterfaceState) Address() device.WGAddress {
+	return i.WGAddress
+}
+
+func (i *InterfaceState) IsUserspaceBind() bool {
+	return i.UserspaceBind
+}
+
+type ShutdownState struct {
+	sync.Mutex
+
+	InterfaceState *InterfaceState `json:"interface_state,omitempty"`
+
+	RouteRules        routeRules    `json:"route_rules,omitempty"`
+	RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
+
+	ACLEntries    aclEntries  `json:"acl_entries,omitempty"`
+	ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
+}
+
+func (s *ShutdownState) Name() string {
+	return "iptables_state"
+}
+
+func (s *ShutdownState) Cleanup() error {
+	ipt, err := Create(s.InterfaceState)
+	if err != nil {
+		return fmt.Errorf("create iptables manager: %w", err)
+	}
+
+	if s.RouteRules != nil {
+		ipt.router.rules = s.RouteRules
+	}
+	if s.RouteIPsetCounter != nil {
+		ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
+	}
+
+	if s.ACLEntries != nil {
+		ipt.aclMgr.entries = s.ACLEntries
+	}
+	if s.ACLIPsetStore != nil {
+		ipt.aclMgr.ipsetStore = s.ACLIPsetStore
+	}
+
+	if err := ipt.Reset(nil); err != nil {
+		return fmt.Errorf("reset iptables manager: %w", err)
+	}
+
+	return nil
+}
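`ShutdownState` carries just enough to rebuild a manager after a crash: `Create(s.InterfaceState)` works because `InterfaceState` satisfies the same `iFaceMapper` subset (`Name`, `Address`, `IsUserspaceBind`). How the state manager invokes `Cleanup` after an unclean shutdown is outside this diff; a plausible sketch, with the interface inferred from the methods implemented here and the loop itself hypothetical:

```go
package example

import (
	log "github.com/sirupsen/logrus"
)

// Inferred from the methods ShutdownState implements.
type registeredState interface {
	Name() string   // key of this state inside the persisted state file
	Cleanup() error // undo whatever the recorded state describes
}

// Hypothetical restore loop, not the actual statemanager implementation.
func cleanupStaleStates(states []registeredState) {
	for _, st := range states { // e.g. an *iptables.ShutdownState restored from JSON
		if err := st.Cleanup(); err != nil {
			log.Errorf("failed to clean up %q: %v", st.Name(), err)
		}
	}
}
```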
diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go
index 556bda0d6..2a40cd9f6 100644
--- a/client/firewall/manager/firewall.go
+++ b/client/firewall/manager/firewall.go
@@ -10,6 +10,8 @@ import (
 	"strings"

 	log "github.com/sirupsen/logrus"
+
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 const (
@@ -52,6 +54,8 @@ const (
 // It declares methods which handle actions required by the
 // Netbird client for ACL and routing functionality
 type Manager interface {
+	Init(stateManager *statemanager.Manager) error
+
 	// AllowNetbird allows netbird interface traffic
 	AllowNetbird() error

@@ -91,7 +95,7 @@ type Manager interface {
 	SetLegacyManagement(legacy bool) error

 	// Reset firewall to the default state
-	Reset() error
+	Reset(stateManager *statemanager.Manager) error

 	// Flush the changes to firewall controller
 	Flush() error
diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go
index 61434f035..ca7b2e59f 100644
--- a/client/firewall/nftables/acl_linux.go
+++ b/client/firewall/nftables/acl_linux.go
@@ -17,7 +17,6 @@ import (
 	"golang.org/x/sys/unix"

 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
-	"github.com/netbirdio/netbird/client/iface"
 	nbnet "github.com/netbirdio/netbird/util/net"
 )

@@ -56,13 +55,6 @@ type AclManager struct {
 	rules map[string]*Rule
 }

-// iFaceMapper defines subset methods of interface required for manager
-type iFaceMapper interface {
-	Name() string
-	Address() iface.WGAddress
-	IsUserspaceBind() bool
-}
-
 func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
 	// sConn is used for creating sets and adding/removing elements from them
 	// it's differ then rConn (which does create new conn for each flush operation)
@@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
 	// overloads netlink with high amount of rules ( > 10000)
 	sConn, err := nftables.New(nftables.AsLasting())
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("create nf conn: %w", err)
 	}

-	m := &AclManager{
+	return &AclManager{
 		rConn:              &nftables.Conn{},
 		sConn:              sConn,
 		wgIface:            wgIface,
@@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam

 		ipsetStore: newIpsetStore(),
 		rules:      make(map[string]*Rule),
-	}
+	}, nil
+}

-	err = m.createDefaultChains()
-	if err != nil {
-		return nil, err
-	}
-
-	return m, nil
+func (m *AclManager) init(workTable *nftables.Table) error {
+	m.workTable = workTable
+	return m.createDefaultChains()
 }

 // AddPeerFiltering rule to the firewall
diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go
index 01b08bd71..ea8912f27 100644
--- a/client/firewall/nftables/manager_linux.go
+++ b/client/firewall/nftables/manager_linux.go
@@ -14,6 +14,8 @@ import (
 	log "github.com/sirupsen/logrus"

 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
+	"github.com/netbirdio/netbird/client/iface"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 const (
@@ -24,6 +26,13 @@ const (
 	chainNameInput = "INPUT"
 )

+// iFaceMapper defines subset methods of interface required for manager
+type iFaceMapper interface {
+	Name() string
+	Address() iface.WGAddress
+	IsUserspaceBind() bool
+}
+
 // Manager of iptables firewall
 type Manager struct {
 	mutex sync.Mutex
@@ -35,30 +44,68 @@ type Manager struct {
 }

 // Create nftables firewall manager
-func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
+func Create(wgIface iFaceMapper) (*Manager, error) {
 	m := &Manager{
 		rConn:   &nftables.Conn{},
 		wgIface: wgIface,
 	}

-	workTable, err := m.createWorkTable()
-	if err != nil {
-		return nil, err
-	}
+	workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}

-	m.router, err = newRouter(context, workTable, wgIface)
+	var err error
+	m.router, err = newRouter(workTable, wgIface)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("create router: %w", err)
 	}

 	m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("create acl manager: %w", err)
 	}

 	return m, nil
 }

+// Init nftables firewall manager
+func (m *Manager) Init(stateManager *statemanager.Manager) error {
+	workTable, err := m.createWorkTable()
+	if err != nil {
+		return fmt.Errorf("create work table: %w", err)
+	}
+
+	if err := m.router.init(workTable); err != nil {
+		return fmt.Errorf("router init: %w", err)
+	}
+
+	if err := m.aclManager.init(workTable); err != nil {
+		// TODO: cleanup router
+		return fmt.Errorf("acl manager init: %w", err)
+	}
+
+	stateManager.RegisterState(&ShutdownState{})
+
+	// We only need to record minimal interface state for potential recreation.
+	// Unlike iptables, which requires tracking individual rules, nftables maintains
+	// a known state (our netbird table plus a few static rules). This allows for easy
+	// cleanup using Reset() without needing to store specific rules.
+	if err := stateManager.UpdateState(&ShutdownState{
+		InterfaceState: &InterfaceState{
+			NameStr:       m.wgIface.Name(),
+			WGAddress:     m.wgIface.Address(),
+			UserspaceBind: m.wgIface.IsUserspaceBind(),
+		},
+	}); err != nil {
+		log.Errorf("failed to update state: %v", err)
+	}
+
+	// persist early
+	if err := stateManager.PersistState(context.Background()); err != nil {
+		log.Errorf("failed to persist state: %v", err)
+	}
+
+	return nil
+}
+
 // AddPeerFiltering rule to the firewall
 //
 // If comment argument is empty firewall manager should set
@@ -183,68 +230,84 @@ func (m *Manager) AllowNetbird() error {

 // SetLegacyManagement sets the route manager to use legacy management
 func (m *Manager) SetLegacyManagement(isLegacy bool) error {
-	oldLegacy := m.router.legacyManagement
+	return firewall.SetLegacyManagement(m.router, isLegacy)
+}

-	if oldLegacy != isLegacy {
-		m.router.legacyManagement = isLegacy
-		log.Debugf("Set legacy management to %v", isLegacy)
+// Reset firewall to the default state
+func (m *Manager) Reset(stateManager *statemanager.Manager) error {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	if err := m.resetNetbirdInputRules(); err != nil {
+		return fmt.Errorf("reset netbird input rules: %v", err)
 	}

-	// client reconnected to a newer mgmt, we need to cleanup the legacy rules
-	if !isLegacy && oldLegacy {
-		if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
-			return fmt.Errorf("remove legacy routing rules: %v", err)
-		}
+	if err := m.router.Reset(); err != nil {
+		return fmt.Errorf("reset router: %v", err)
+	}

-		log.Debugf("Legacy routing rules removed")
+	if err := m.cleanupNetbirdTables(); err != nil {
+		return fmt.Errorf("cleanup netbird tables: %v", err)
+	}
+
+	if err := m.rConn.Flush(); err != nil {
+		return fmt.Errorf(flushError, err)
+	}
+
+	if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
+		return fmt.Errorf("delete state: %v", err)
 	}

 	return nil
 }

-// Reset firewall to the default state
-func (m *Manager) Reset() error {
-	m.mutex.Lock()
-	defer m.mutex.Unlock()
-
+func (m *Manager) resetNetbirdInputRules() error {
 	chains, err := m.rConn.ListChains()
 	if err != nil {
-		return fmt.Errorf("list of chains: %w", err)
+		return fmt.Errorf("list chains: %w", err)
 	}

+	m.deleteNetbirdInputRules(chains)
+
+	return nil
+}
+
+func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
 	for _, c := range chains {
-		// delete Netbird allow input traffic rule if it exists
 		if c.Table.Name == "filter" && c.Name == "INPUT" {
 			rules, err := m.rConn.GetRules(c.Table, c)
 			if err != nil {
 				log.Errorf("get rules for chain %q: %v", c.Name, err)
 				continue
 			}
-			for _, r := range rules {
-				if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
-					if err := m.rConn.DelRule(r); err != nil {
-						log.Errorf("delete rule: %v", err)
-					}
-				}
+
+			m.deleteMatchingRules(rules)
+		}
+	}
+}
+
+func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
+	for _, r := range rules {
+		if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
+			if err := m.rConn.DelRule(r); err != nil {
+				log.Errorf("delete rule: %v", err)
 			}
 		}
 	}
+}

-	if err := m.router.Reset(); err != nil {
-		return fmt.Errorf("reset forward rules: %v", err)
-	}
-
+func (m *Manager) cleanupNetbirdTables() error {
 	tables, err := m.rConn.ListTables()
 	if err != nil {
-		return fmt.Errorf("list of tables: %w", err)
+		return fmt.Errorf("list tables: %w", err)
 	}
+
 	for _, t := range tables {
 		if t.Name == tableNameNetbird {
 			m.rConn.DelTable(t)
 		}
 	}
-
-	return m.rConn.Flush()
+	return nil
 }

 // Flush rule/chain/set operations from the buffer
diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go
index bbe18ab07..77f4f0306 100644
--- a/client/firewall/nftables/manager_linux_test.go
+++ b/client/firewall/nftables/manager_linux_test.go
@@ -1,7 +1,6 @@
 package nftables

 import (
-	"context"
 	"fmt"
 	"net"
 	"net/netip"
@@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }

 func TestNftablesManager(t *testing.T) {
 	// just check on the local interface
-	manager, err := Create(context.Background(), ifaceMock)
+	manager, err := Create(ifaceMock)
 	require.NoError(t, err)
+	require.NoError(t, manager.Init(nil))

 	time.Sleep(time.Second * 3)

 	defer func() {
-		err = manager.Reset()
+		err = manager.Reset(nil)
 		require.NoError(t, err, "failed to reset")
 		time.Sleep(time.Second)
 	}()
@@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
 	// established rule remains
 	require.Len(t, rules, 1, "expected 1 rules after deletion")

-	err = manager.Reset()
+	err = manager.Reset(nil)
 	require.NoError(t, err, "failed to reset")
 }

@@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
 	for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
 		t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
 			// just check on the local interface
-			manager, err := Create(context.Background(), mock)
+			manager, err := Create(mock)
 			require.NoError(t, err)
+			require.NoError(t, manager.Init(nil))

 			time.Sleep(time.Second * 3)

 			defer func() {
-				if err := manager.Reset(); err != nil {
+				if err := manager.Reset(nil); err != nil {
 					t.Errorf("clear the manager state: %v", err)
 				}
 				time.Sleep(time.Second)
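The comment in `Init` above is the key design point: nftables groups everything under the netbird table, and deleting a table removes all chains, sets, and rules inside it, so only the interface state needs persisting. The same cleanup idea as a standalone sketch against the google/nftables client (table name borrowed from `tableNameNetbird`; error handling abbreviated):

```go
package main

import (
	"log"

	"github.com/google/nftables"
)

// Minimal sketch: nftables cleanup by table deletion.
func main() {
	conn := &nftables.Conn{}

	tables, err := conn.ListTables()
	if err != nil {
		log.Fatalf("list tables: %v", err)
	}
	for _, t := range tables {
		if t.Name == "netbird" {
			conn.DelTable(t) // drops all chains, sets and rules in the table
		}
	}
	if err := conn.Flush(); err != nil { // commit queued operations to the kernel
		log.Fatalf("flush: %v", err)
	}
}
```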
diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go
index 03526fee7..0e7ea71b7 100644
--- a/client/firewall/nftables/router_linux.go
+++ b/client/firewall/nftables/router_linux.go
@@ -2,7 +2,6 @@
 import (
 	"bytes"
-	"context"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -40,8 +39,6 @@ var (
 )

 type router struct {
-	ctx              context.Context
-	stop             context.CancelFunc
 	conn             *nftables.Conn
 	workTable        *nftables.Table
 	filterTable      *nftables.Table
@@ -54,12 +51,8 @@ type router struct {
 	legacyManagement bool
 }

-func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
-	ctx, cancel := context.WithCancel(parentCtx)
-
+func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
 	r := &router{
-		ctx:       ctx,
-		stop:      cancel,
 		conn:      &nftables.Conn{},
 		workTable: workTable,
 		chains:    make(map[string]*nftables.Chain),
@@ -78,20 +71,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
 		if errors.Is(err, errFilterTableNotFound) {
 			log.Warnf("table 'filter' not found for forward rules")
 		} else {
-			return nil, err
+			return nil, fmt.Errorf("load filter table: %w", err)
 		}
 	}

-	err = r.removeAcceptForwardRules()
-	if err != nil {
+	return r, nil
+}
+
+func (r *router) init(workTable *nftables.Table) error {
+	r.workTable = workTable
+
+	if err := r.removeAcceptForwardRules(); err != nil {
 		log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
 	}

-	err = r.createContainers()
-	if err != nil {
-		log.Errorf("failed to create containers for route: %s", err)
+	if err := r.createContainers(); err != nil {
+		return fmt.Errorf("create containers: %w", err)
 	}
-	return r, err
+
+	return nil
 }

 // Reset cleans existing nftables default forward rules from the system
@@ -553,7 +551,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
 		}
 		if err := r.conn.DelRule(rule); err != nil {
 			merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
+		} else {
+			delete(r.rules, k)
 		}
+
 	}
 	return nberrors.FormatErrorOrNil(merr)
 }
diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go
index c07111b4e..19ed48991 100644
--- a/client/firewall/nftables/router_linux_test.go
+++ b/client/firewall/nftables/router_linux_test.go
@@ -3,7 +3,6 @@ package nftables
 import (
-	"context"
 	"encoding/binary"
 	"net/netip"
 	"os/exec"
@@ -40,8 +39,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {

 	for _, testCase := range test.InsertRuleTestCases {
 		t.Run(testCase.Name, func(t *testing.T) {
-			manager, err := newRouter(context.TODO(), table, ifaceMock)
+			manager, err := newRouter(table, ifaceMock)
 			require.NoError(t, err, "failed to create router")
+			require.NoError(t, manager.init(table))

 			nftablesTestingClient := &nftables.Conn{}
@@ -142,8 +142,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {

 	for _, testCase := range test.RemoveRuleTestCases {
 		t.Run(testCase.Name, func(t *testing.T) {
-			manager, err := newRouter(context.TODO(), table, ifaceMock)
+			manager, err := newRouter(table, ifaceMock)
 			require.NoError(t, err, "failed to create router")
+			require.NoError(t, manager.init(table))

 			nftablesTestingClient := &nftables.Conn{}
@@ -210,8 +211,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {

 	defer deleteWorkTable()

-	r, err := newRouter(context.Background(), workTable, ifaceMock)
+	r, err := newRouter(workTable, ifaceMock)
 	require.NoError(t, err, "Failed to create router")
+	require.NoError(t, r.init(workTable))

 	defer func(r *router) {
 		require.NoError(t, r.Reset(), "Failed to reset rules")
@@ -376,8 +378,9 @@ func TestNftablesCreateIpSet(t *testing.T) {

 	defer deleteWorkTable()

-	r, err := newRouter(context.Background(), workTable, ifaceMock)
+	r, err := newRouter(workTable, ifaceMock)
 	require.NoError(t, err, "Failed to create router")
+	require.NoError(t, r.init(workTable))

 	defer func() {
 		require.NoError(t, r.Reset(), "Failed to reset router")
diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state.go
new file mode 100644
index 000000000..7027fe987
--- /dev/null
+++ b/client/firewall/nftables/state.go
@@ -0,0 +1 @@
+package nftables
diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go
new file mode 100644
index 000000000..a68c8b8b8
--- /dev/null
+++ b/client/firewall/nftables/state_linux.go
@@ -0,0 +1,47 @@
+package nftables
+
+import (
+	"fmt"
+
+	"github.com/netbirdio/netbird/client/iface"
+	"github.com/netbirdio/netbird/client/iface/device"
+)
+
+type InterfaceState struct {
+	NameStr       string          `json:"name"`
+	WGAddress     iface.WGAddress `json:"wg_address"`
+	UserspaceBind bool            `json:"userspace_bind"`
+}
+
+func (i *InterfaceState) Name() string {
+	return i.NameStr
+}
+
+func (i *InterfaceState) Address() device.WGAddress {
+	return i.WGAddress
+}
+
+func (i *InterfaceState) IsUserspaceBind() bool {
+	return i.UserspaceBind
+}
+
+type ShutdownState struct {
+	InterfaceState *InterfaceState `json:"interface_state,omitempty"`
+}
+
+func (s *ShutdownState) Name() string {
+	return "nftables_state"
+}
+
+func (s *ShutdownState) Cleanup() error {
+	nft, err := Create(s.InterfaceState)
+	if err != nil {
+		return fmt.Errorf("create nftables manager: %w", err)
+	}
+
+	if err := nft.Reset(nil); err != nil {
+		return fmt.Errorf("reset nftables manager: %w", err)
+	}
+
+	return nil
+}
diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go
index 2275dad39..cefc81a3c 100644
--- a/client/firewall/uspfilter/allow_netbird.go
+++ b/client/firewall/uspfilter/allow_netbird.go
@@ -2,8 +2,10 @@
 package uspfilter

+import "github.com/netbirdio/netbird/client/internal/statemanager"
+
 // Reset firewall to the default state
-func (m *Manager) Reset() error {
+func (m *Manager) Reset(stateManager *statemanager.Manager) error {
 	m.mutex.Lock()
 	defer m.mutex.Unlock()

@@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
 	m.incomingRules = make(map[string]RuleSet)

 	if m.nativeFirewall != nil {
-		return m.nativeFirewall.Reset()
+		return m.nativeFirewall.Reset(stateManager)
 	}
 	return nil
 }
diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go
index 34274564f..d3732301e 100644
--- a/client/firewall/uspfilter/allow_netbird_windows.go
+++ b/client/firewall/uspfilter/allow_netbird_windows.go
@@ -6,6 +6,8 @@ import (
 	"syscall"

 	log "github.com/sirupsen/logrus"
+
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 type action string
@@ -17,7 +19,7 @@ const (
 )

 // Reset firewall to the default state
-func (m *Manager) Reset() error {
+func (m *Manager) Reset(*statemanager.Manager) error {
 	m.mutex.Lock()
 	defer m.mutex.Unlock()
diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go
index 0e3ee9799..af5dc6733 100644
--- a/client/firewall/uspfilter/uspfilter.go
+++ b/client/firewall/uspfilter/uspfilter.go
@@ -14,6 +14,7 @@ import (
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/iface"
 	"github.com/netbirdio/netbird/client/iface/device"
+	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 const layerTypeAll = 0
@@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
 	return m, nil
 }

+func (m *Manager) Init(*statemanager.Manager) error {
+	return nil
+}
+
 func (m *Manager) IsServerRouteSupported() bool {
 	if m.nativeFirewall == nil {
 		return false
@@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering(
 	return []firewall.Rule{&r}, nil
 }

-func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
+func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
 	if m.nativeFirewall == nil {
 		return nil, errRouteNotSupported
 	}
@@ -232,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
 }

 // SetLegacyManagement doesn't need to be implemented for this manager
-func (m *Manager) SetLegacyManagement(_ bool) error {
-	return nil
+func (m *Manager) SetLegacyManagement(isLegacy bool) error {
+	if m.nativeFirewall == nil {
+		return errRouteNotSupported
+	}
+	return m.nativeFirewall.SetLegacyManagement(isLegacy)
 }

 // Flush doesn't need to be implemented for this manager
diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go
index c188deea4..d7c93cb7f 100644
--- a/client/firewall/uspfilter/uspfilter_test.go
+++ b/client/firewall/uspfilter/uspfilter_test.go
@@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) {
 		return
 	}

-	err = m.Reset()
+	err = m.Reset(nil)
 	if err != nil {
 		t.Errorf("failed to reset Manager: %v", err)
 		return
@@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
 		return
 	}

-	if err = m.Reset(); err != nil {
+	if err = m.Reset(nil); err != nil {
 		t.Errorf("failed to reset Manager: %v", err)
 		return
 	}
@@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
 			time.Sleep(time.Second)

 			defer func() {
-				if err := manager.Reset(); err != nil {
+				if err := manager.Reset(nil); err != nil {
 					t.Errorf("clear the manager state: %v", err)
 				}
 				time.Sleep(time.Second)
diff --git a/client/iface/bind/bind.go b/client/iface/bind/bind.go
deleted file mode 100644
index ba6153cb7..000000000
--- a/client/iface/bind/bind.go
+++ /dev/null
@@ -1,142 +0,0 @@
-package bind
-
-import (
-	"fmt"
-	"net"
-	"runtime"
-	"sync"
-
-	"github.com/pion/stun/v2"
-	"github.com/pion/transport/v3"
-	log "github.com/sirupsen/logrus"
-	"golang.org/x/net/ipv4"
-	wgConn "golang.zx2c4.com/wireguard/conn"
-)
-
-type receiverCreator struct {
-	iceBind *ICEBind
-}
-
-func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
-	return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
-}
-
-type ICEBind struct {
-	*wgConn.StdNetBind
-
-	muUDPMux sync.Mutex
-
-	transportNet transport.Net
-	udpMux       *UniversalUDPMuxDefault
-
-	filterFn FilterFn
-}
-
-func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
-	ib := &ICEBind{
-		transportNet: transportNet,
-		filterFn:     filterFn,
-	}
-
-	rc := receiverCreator{
-		ib,
-	}
-	ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
-	return ib
-}
-
-// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
-func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
-	s.muUDPMux.Lock()
-	defer s.muUDPMux.Unlock()
-	if s.udpMux == nil {
-		return nil, fmt.Errorf("ICEBind has not been initialized yet")
-	}
-
-	return s.udpMux, nil
-}
-
-func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
-	s.muUDPMux.Lock()
-	defer s.muUDPMux.Unlock()
-
-	s.udpMux = NewUniversalUDPMuxDefault(
-		UniversalUDPMuxParams{
-			UDPConn:  conn,
-			Net:      s.transportNet,
-			FilterFn: s.filterFn,
-		},
-	)
-	return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
-		msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
-		defer ipv4MsgsPool.Put(msgs)
-		for i := range bufs {
-			(*msgs)[i].Buffers[0] = bufs[i]
-		}
-		var numMsgs int
-		if runtime.GOOS == "linux" {
-			numMsgs, err = pc.ReadBatch(*msgs, 0)
-			if err != nil {
-				return 0, err
-			}
-		} else {
-			msg := &(*msgs)[0]
-			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
-			if err != nil {
-				return 0, err
-			}
-			numMsgs = 1
-		}
-		for i := 0; i < numMsgs; i++ {
-			msg := &(*msgs)[i]
-
-			// todo: handle err
-			ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
-			if ok {
-				sizes[i] = 0
-			} else {
-				sizes[i] = msg.N
-			}
-
-			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-			ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
-			wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
-			eps[i] = ep
-		}
-		return numMsgs, nil
-	}
-}
-
-func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
-	for i := range buffers {
-		if !stun.IsMessage(buffers[i]) {
-			continue
-		}
-
-		msg, err := s.parseSTUNMessage(buffers[i][:n])
-		if err != nil {
-			buffers[i] = []byte{}
-			return true, err
-		}
-
-		muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
-		if muxErr != nil {
-			log.Warnf("failed to handle STUN packet")
-		}
-
-		buffers[i] = []byte{}
-		return true, nil
-	}
-	return false, nil
-}
-
-func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
-	msg := &stun.Message{
-		Raw: raw,
-	}
-	if err := msg.Decode(); err != nil {
-		return nil, err
-	}
-
-	return msg, nil
-}
diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go
new file mode 100644
index 000000000..1926ff88f
--- /dev/null
+++ b/client/iface/bind/endpoint.go
@@ -0,0 +1,5 @@
+package bind
+
+import wgConn "golang.zx2c4.com/wireguard/conn"
+
+type Endpoint = wgConn.StdNetEndpoint
diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go
new file mode 100644
index 000000000..a9c25950d
--- /dev/null
+++ b/client/iface/bind/ice_bind.go
@@ -0,0 +1,275 @@
+package bind
+
+import (
+	"fmt"
+	"net"
+	"net/netip"
+	"runtime"
+	"strings"
+	"sync"
+
+	"github.com/pion/stun/v2"
+	"github.com/pion/transport/v3"
+	log "github.com/sirupsen/logrus"
+	"golang.org/x/net/ipv4"
+	wgConn "golang.zx2c4.com/wireguard/conn"
+)
+
+type RecvMessage struct {
+	Endpoint *Endpoint
+	Buffer   []byte
+}
+
+type receiverCreator struct {
+	iceBind *ICEBind
+}
+
+func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
+	return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
+}
+
+// ICEBind is a bind implementation with two main features:
+// 1. filter out STUN messages and handle them
+// 2. forward the received packets to the WireGuard interface from the relayed connection
+//
+// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
+// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid
+// using the port because the port info is not exported from the wgConn.Endpoint in the Send function.
+type ICEBind struct {
+	*wgConn.StdNetBind
+	RecvChan chan RecvMessage
+
+	transportNet transport.Net
+	filterFn     FilterFn
+	endpoints    map[netip.Addr]net.Conn
+	endpointsMu  sync.Mutex
+	// every time Close() is called (i.e. BindUpdate()) we need to exit from receiveRelayed and create a
+	// new closed channel. With closedChanMu we can safely close the channel and create a new one
+	closedChan   chan struct{}
+	closedChanMu sync.RWMutex // protects the closedChan recreation from concurrent reads
+	closed       bool
+
+	muUDPMux sync.Mutex
+	udpMux   *UniversalUDPMuxDefault
+}
+
+func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
+	b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
+	ib := &ICEBind{
+		StdNetBind:   b,
+		RecvChan:     make(chan RecvMessage, 1),
+		transportNet: transportNet,
+		filterFn:     filterFn,
+		endpoints:    make(map[netip.Addr]net.Conn),
+		closedChan:   make(chan struct{}),
+		closed:       true,
+	}
+
+	rc := receiverCreator{
+		ib,
+	}
+	ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
+	return ib
+}
+
+func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
+	s.closed = false
+	s.closedChanMu.Lock()
+	s.closedChan = make(chan struct{})
+	s.closedChanMu.Unlock()
+	fns, port, err := s.StdNetBind.Open(uport)
+	if err != nil {
+		return nil, 0, err
+	}
+	fns = append(fns, s.receiveRelayed)
+	return fns, port, nil
+}
+
+func (s *ICEBind) Close() error {
+	if s.closed {
+		return nil
+	}
+	s.closed = true
+
+	close(s.closedChan)
+
+	return s.StdNetBind.Close()
+}
+
+// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
+func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
+	s.muUDPMux.Lock()
+	defer s.muUDPMux.Unlock()
+	if s.udpMux == nil {
+		return nil, fmt.Errorf("ICEBind has not been initialized yet")
+	}
+
+	return s.udpMux, nil
+}
+
+func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
+	fakeUDPAddr, err := fakeAddress(peerAddress)
+	if err != nil {
+		return nil, err
+	}
+
+	// force IPv4
+	fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
+	if !ok {
+		return nil, fmt.Errorf("failed to convert IP to netip.Addr")
+	}
+
+	b.endpointsMu.Lock()
+	b.endpoints[fakeAddr] = conn
+	b.endpointsMu.Unlock()
+
+	return fakeUDPAddr, nil
+}
+
+func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
+	fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
+	if !ok {
+		log.Warnf("failed to convert IP to netip.Addr")
+		return
+	}
+
+	b.endpointsMu.Lock()
+	defer b.endpointsMu.Unlock()
+	delete(b.endpoints, fakeAddr)
+}
+
+func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
+	b.endpointsMu.Lock()
+	conn, ok := b.endpoints[ep.DstIP()]
+	b.endpointsMu.Unlock()
+	if !ok {
+		return b.StdNetBind.Send(bufs, ep)
+	}
+
+	for _, buf := range bufs {
+		if _, err := conn.Write(buf); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
+	s.muUDPMux.Lock()
+	defer s.muUDPMux.Unlock()
+
+	s.udpMux = NewUniversalUDPMuxDefault(
+		UniversalUDPMuxParams{
+			UDPConn:  conn,
+			Net:      s.transportNet,
+			FilterFn: s.filterFn,
+		},
+	)
+	return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
+		msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
+		defer ipv4MsgsPool.Put(msgs)
+		for i := range bufs {
+			(*msgs)[i].Buffers[0] = bufs[i]
+		}
+		var numMsgs int
+		if runtime.GOOS == "linux" {
+			numMsgs, err = pc.ReadBatch(*msgs, 0)
+			if err != nil {
+				return 0, err
+			}
+		} else {
+			msg := &(*msgs)[0]
+			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+			if err != nil {
+				return 0, err
+			}
+			numMsgs = 1
+		}
+		for i := 0; i < numMsgs; i++ {
+			msg := &(*msgs)[i]
+
+			// todo: handle err
+			ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
+			if ok {
+				sizes[i] = 0
+			} else {
+				sizes[i] = msg.N
+			}
+
+			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+			ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+			wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
+			eps[i] = ep
+		}
+		return numMsgs, nil
+	}
+}
+
+func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
+	for i := range buffers {
+		if !stun.IsMessage(buffers[i]) {
+			continue
+		}
+
+		msg, err := s.parseSTUNMessage(buffers[i][:n])
+		if err != nil {
+			buffers[i] = []byte{}
+			return true, err
+		}
+
+		muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
+		if muxErr != nil {
+			log.Warnf("failed to handle STUN packet")
+		}
+
+		buffers[i] = []byte{}
+		return true, nil
+	}
+	return false, nil
+}
+
+func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
+	msg := &stun.Message{
+		Raw: raw,
+	}
+	if err := msg.Decode(); err != nil {
+		return nil, err
+	}
+
+	return msg, nil
+}
+
+// receiveRelayed is a receive function used to receive packets from the relayed connection and forward them to
+// WireGuard. The critical part is to not block once Close() has been called.
+func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
+	c.closedChanMu.RLock()
+	defer c.closedChanMu.RUnlock()
+
+	select {
+	case <-c.closedChan:
+		return 0, net.ErrClosed
+	case msg, ok := <-c.RecvChan:
+		if !ok {
+			return 0, net.ErrClosed
+		}
+		copy(buffs[0], msg.Buffer)
+		sizes[0] = len(msg.Buffer)
+		eps[0] = wgConn.Endpoint(msg.Endpoint)
+		return 1, nil
+	}
+}
+
+// fakeAddress returns a fake address that is used as an identifier for the peer.
+// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
+func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
+	octets := strings.Split(peerAddress.IP.String(), ".")
+	if len(octets) != 4 {
+		return nil, fmt.Errorf("invalid IP format")
+	}
+
+	newAddr := &net.UDPAddr{
+		IP:   net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
+		Port: peerAddress.Port,
+	}
+	return newAddr, nil
+}
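`ice_bind.go` replaces the deleted `bind.go` and adds the relay path: `SetEndpoint` maps a peer to a fake `127.1.x.x` address and stores its relay `net.Conn`, `Send` writes through that conn when the destination matches, and inbound relayed packets are pushed into `RecvChan`, from which `receiveRelayed` hands them to WireGuard. A hedged sketch of the glue a relay client might run; the pump loop and buffer sizing are illustrative, only the `bind` calls come from this diff:

```go
package example

import (
	"net"

	"github.com/netbirdio/netbird/client/iface/bind"
)

// Illustrative pump: forward packets from a relayed net.Conn into WireGuard.
func pumpRelayed(iceBind *bind.ICEBind, peerAddr *net.UDPAddr, relayConn net.Conn) error {
	// Register the relay conn; WireGuard addresses this peer as 127.1.x.x.
	fakeAddr, err := iceBind.SetEndpoint(peerAddr, relayConn)
	if err != nil {
		return err
	}
	defer iceBind.RemoveEndpoint(fakeAddr)

	ep := &bind.Endpoint{AddrPort: fakeAddr.AddrPort()}
	buf := make([]byte, 1500) // assumed MTU-sized scratch buffer
	for {
		n, err := relayConn.Read(buf)
		if err != nil {
			return err
		}
		// receiveRelayed hands this to WireGuard unless the bind is closed.
		iceBind.RecvChan <- bind.RecvMessage{
			Endpoint: ep,
			Buffer:   append([]byte(nil), buf[:n]...), // copy: buf is reused
		}
	}
}
```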
remove allocation + wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil + } +} + +func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { + for i := range buffers { + if !stun.IsMessage(buffers[i]) { + continue + } + + msg, err := s.parseSTUNMessage(buffers[i][:n]) + if err != nil { + buffers[i] = []byte{} + return true, err + } + + muxErr := s.udpMux.HandleSTUNMessage(msg, addr) + if muxErr != nil { + log.Warnf("failed to handle STUN packet: %s", muxErr) + } + + buffers[i] = []byte{} + return true, nil + } + return false, nil +} + +func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) { + msg := &stun.Message{ + Raw: raw, + } + if err := msg.Decode(); err != nil { + return nil, err + } + + return msg, nil +} + +// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward them to +// WireGuard. The critical part is that it must not block once Close() has been called. +func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + c.closedChanMu.RLock() + defer c.closedChanMu.RUnlock() + + select { + case <-c.closedChan: + return 0, net.ErrClosed + case msg, ok := <-c.RecvChan: + if !ok { + return 0, net.ErrClosed + } + copy(buffs[0], msg.Buffer) + sizes[0] = len(msg.Buffer) + eps[0] = wgConn.Endpoint(msg.Endpoint) + return 1, nil + } +} + +// fakeAddress returns a fake address that is used as an identifier for the peer. +// The fake address is in the format 127.1.x.x where x.x is the last two octets of the peer address. +func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) { + octets := strings.Split(peerAddress.IP.String(), ".") + if len(octets) != 4 { + return nil, fmt.Errorf("invalid IP format") + } + + newAddr := &net.UDPAddr{ + IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])), + Port: peerAddress.Port, + } + return newAddr, nil +} diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 29e3f409d..fac2ba63d 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -5,7 +5,6 @@ package device import ( "strings" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" @@ -31,13 +30,13 @@ type WGTunDevice struct { configurer WGConfigurer } -func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice { +func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { return &WGTunDevice{ address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, tunAdapter: tunAdapter, } } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 03e85a7f1..b5a128bc1 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -6,7 +6,6 @@ import ( "fmt" "os/exec" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -29,14 +28,14 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string,
mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 226e8a2e0..b9591e0b8 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -6,7 +6,6 @@ package device import ( "os" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" @@ -30,13 +29,13 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice { +func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, tunFd: tunFd, } } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 440a1ca19..f5d39e9e0 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -6,7 +6,6 @@ package device import ( "fmt" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" @@ -31,7 +30,7 @@ type TunNetstackDevice struct { configurer WGConfigurer } -func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { +func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -39,7 +38,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4175f6556..643d77565 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -7,7 +7,6 @@ import ( "os" "runtime" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -30,7 +29,7 @@ type USPDevice struct { configurer WGConfigurer } -func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { +func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { log.Infof("using userspace bind mode") checkUser() @@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn)} + iceBind: iceBind, + } } func (t *USPDevice) Create() (WGConfigurer, error) { diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f3e216ccd..86968d06d 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -4,7 +4,6 @@ import ( "fmt" "net/netip" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device" 
@@ -32,14 +31,14 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/iface.go b/client/iface/iface.go index accf5ce0a..1fb9c2691 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -6,12 +6,16 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) const ( @@ -22,14 +26,35 @@ const ( type WGAddress = device.WGAddress +type wgProxyFactory interface { + GetProxy() wgproxy.Proxy + Free() error +} + +type WGIFaceOpts struct { + IFaceName string + Address string + WGPort int + WGPrivKey string + MTU int + MobileArgs *device.MobileIFaceArguments + TransportNet transport.Net + FilterFn bind.FilterFn +} + // WGIface represents an interface instance type WGIface struct { tun WGTunDevice userspaceBind bool mu sync.Mutex - configurer device.WGConfigurer - filter device.PacketFilter + configurer device.WGConfigurer + filter device.PacketFilter + wgProxyFactory wgProxyFactory +} + +func (w *WGIface) GetProxy() wgproxy.Proxy { + return w.wgProxyFactory.GetProxy() } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -124,22 +149,26 @@ func (w *WGIface) Close() error { w.mu.Lock() defer w.mu.Unlock() - err := w.tun.Close() - if err != nil { - return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err) + var result *multierror.Error + + if err := w.wgProxyFactory.Free(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) } - err = w.waitUntilRemoved() - if err != nil { + if err := w.tun.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) + } + + if err := w.waitUntilRemoved(); err != nil { log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) - err = w.Destroy() - if err != nil { - return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err) + if err := w.Destroy(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)) + return errors.FormatErrorOrNil(result) } log.Infof("interface %s successfully removed", w.Name()) } - return nil + return errors.FormatErrorOrNil(result) } // SetFilter sets packet filters for the userspace implementation diff --git a/client/iface/iface_android.go b/client/iface/iface_android.go deleted file mode 100644 index 5ed476e70..000000000 --- a/client/iface/iface_android.go +++ /dev/null @@ -1,43 +0,0 @@ -package iface - -import ( - "fmt" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName 
string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{ - tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), - userspaceBind: true, - } - return wgIFace, nil -} - -// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { - w.mu.Lock() - defer w.mu.Unlock() - - cfgr, err := w.tun.Create(routes, dns, searchDomains) - if err != nil { - return err - } - w.configurer = cfgr - return nil -} - -// Create this function make sense on mobile only -func (w *WGIface) Create() error { - return fmt.Errorf("this function has not implemented on this platform") -} diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go index f389019ed..5e17c6d41 100644 --- a/client/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -2,6 +2,8 @@ package iface +import "fmt" + // Create creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. // this function is different on Android @@ -17,3 +19,8 @@ func (w *WGIface) Create() error { w.configurer = cfgr return nil } + +// CreateOnAndroid only makes sense on mobile platforms +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function is not implemented on non-mobile platforms") +} diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go new file mode 100644 index 000000000..373a9c95a --- /dev/null +++ b/client/iface/iface_create_android.go @@ -0,0 +1,24 @@ +package iface + +import ( + "fmt" +) + +// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one.
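+// The routes, dns and searchDomains parameters are handed to the underlying tun device's Create call.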
+func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { + w.mu.Lock() + defer w.mu.Unlock() + + cfgr, err := w.tun.Create(routes, dns, searchDomains) + if err != nil { + return err + } + w.configurer = cfgr + return nil +} + +// Create is not implemented on Android; use CreateOnAndroid instead +func (w *WGIface) Create() error { + return fmt.Errorf("this function is not implemented on this platform") +} diff --git a/client/iface/iface_darwin.go b/client/iface/iface_create_darwin.go similarity index 50% rename from client/iface/iface_darwin.go rename to client/iface/iface_create_darwin.go index b46ea0f80..1d91bce54 100644 --- a/client/iface/iface_darwin.go +++ b/client/iface/iface_create_darwin.go @@ -7,39 +7,8 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" ) -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{ - userspaceBind: true, - } - - if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) - return wgIFace, nil - } - - wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) - - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on this platform") -} - // Create creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. // this function is different on Android @@ -65,3 +34,8 @@ func (w *WGIface) Create() error { return backoff.Retry(operation, backOff) } + +// CreateOnAndroid only makes sense on mobile platforms +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function is not implemented on this platform") +} diff --git a/client/iface/iface_guid_windows.go b/client/iface/iface_guid_windows.go new file mode 100644 index 000000000..49492fd3d --- /dev/null +++ b/client/iface/iface_guid_windows.go @@ -0,0 +1,10 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/device" +) + +// GetInterfaceGUIDString returns an interface GUID.
This is useful on Windows only +func (w *WGIface) GetInterfaceGUIDString() (string, error) { + return w.tun.(*device.TunDevice).GetInterfaceGUIDString() +} diff --git a/client/iface/iface_ios.go b/client/iface/iface_ios.go deleted file mode 100644 index fc0214748..000000000 --- a/client/iface/iface_ios.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build ios - -package iface - -import ( - "fmt" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - wgIFace := &WGIface{ - tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), - userspaceBind: true, - } - return wgIFace, nil -} - -// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on this platform") -} diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go index 703da9ce0..d91a7224f 100644 --- a/client/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type MockWGIface struct { @@ -30,6 +31,7 @@ type MockWGIface struct { GetDeviceFunc func() *device.FilteredDevice GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) + GetProxyFunc func() wgproxy.Proxy } func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { @@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice { func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } + +func (m *MockWGIface) GetProxy() wgproxy.Proxy { + //TODO implement me + panic("implement me") +} diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go new file mode 100644 index 000000000..69a8d1fd4 --- /dev/null +++ b/client/iface/iface_new_android.go @@ -0,0 +1,24 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + wgIFace := &WGIface{ + userspaceBind: true, + tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil +} diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go new file mode 100644 index 000000000..a92d74e0f --- /dev/null +++ b/client/iface/iface_new_darwin.go @@ -0,0 +1,34 @@ +//go:build !ios + +package iface + +import ( + 
"github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + var tun WGTunDevice + if netstack.IsEnabled() { + tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + } else { + tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + } + + wgIFace := &WGIface{ + userspaceBind: true, + tun: tun, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil +} diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go new file mode 100644 index 000000000..363f95e11 --- /dev/null +++ b/client/iface/iface_new_ios.go @@ -0,0 +1,26 @@ +//go:build ios + +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + wgIFace := &WGIface{ + tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil +} diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_unix.go new file mode 100644 index 000000000..f10b17c9a --- /dev/null +++ b/client/iface/iface_new_unix.go @@ -0,0 +1,45 @@ +//go:build (linux && !android) || freebsd + +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + if device.WireGuardModuleIsLoaded() { + wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) + wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort) + return wgIFace, nil + } + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + return nil, fmt.Errorf("couldn't check or load tun 
module") +} diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go new file mode 100644 index 000000000..2e6355496 --- /dev/null +++ b/client/iface/iface_new_windows.go @@ -0,0 +1,32 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + var tun WGTunDevice + if netstack.IsEnabled() { + tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + } else { + tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + } + + wgIFace := &WGIface{ + userspaceBind: true, + tun: tun, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil + +} diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index 87a68addb..85db9cacb 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -45,7 +45,16 @@ func TestWGIface_UpdateAddr(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: addr, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -118,7 +127,16 @@ func Test_CreateInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: 33100, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -153,7 +171,16 @@ func Test_Close(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -189,7 +216,16 @@ func TestRecreation(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -252,7 +288,15 @@ func Test_ConfigureInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -300,7 +344,16 @@ func Test_UpdatePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: 33100, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + 
+ iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -361,7 +414,16 @@ func Test_RemovePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: 33100, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -418,7 +480,15 @@ func Test_ConnectPeers(t *testing.T) { guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) + optsPeer1 := WGIFaceOpts{ + IFaceName: peer1ifaceName, + Address: peer1wgIP, + WGPort: peer1wgPort, + WGPrivKey: peer1Key.String(), + MTU: DefaultMTU, + TransportNet: newNet, + } + iface1, err := NewWGIFace(optsPeer1) if err != nil { t.Fatal(err) } @@ -432,7 +502,12 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } - peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort)) + localIP, err := getLocalIP() + if err != nil { + t.Fatal(err) + } + + peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer1wgPort)) if err != nil { t.Fatal(err) } @@ -444,7 +519,17 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil) + + optsPeer2 := WGIFaceOpts{ + IFaceName: peer2ifaceName, + Address: peer2wgIP, + WGPort: peer2wgPort, + WGPrivKey: peer2Key.String(), + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface2, err := NewWGIFace(optsPeer2) if err != nil { t.Fatal(err) } @@ -458,7 +543,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } - peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort)) + peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer2wgPort)) if err != nil { t.Fatal(err) } @@ -527,3 +612,28 @@ func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { } return wgtypes.Peer{}, fmt.Errorf("peer not found") } + +func getLocalIP() (string, error) { + // Get all interfaces + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", err + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + if ipNet.IP.IsLoopback() { + continue + } + + if ipNet.IP.To4() == nil { + continue + } + return ipNet.IP.String(), nil + } + + return "", fmt.Errorf("no local IP found") +} diff --git a/client/iface/iface_unix.go b/client/iface/iface_unix.go deleted file mode 100644 index 09dbb2c1f..000000000 --- a/client/iface/iface_unix.go +++ /dev/null @@ -1,49 +0,0 @@ -//go:build (linux && !android) || freebsd - -package iface - -import ( - "fmt" - "runtime" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{} - - // move the kernel/usp/netstack preference evaluation to upper 
layer - if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) - wgIFace.userspaceBind = true - return wgIFace, nil - } - - if device.WireGuardModuleIsLoaded() { - wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) - wgIFace.userspaceBind = false - return wgIFace, nil - } - - if !device.ModuleTunIsLoaded() { - return nil, fmt.Errorf("couldn't check or load tun module") - } - wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) - wgIFace.userspaceBind = true - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS) -} diff --git a/client/iface/iface_windows.go b/client/iface/iface_windows.go deleted file mode 100644 index 6845ef3dd..000000000 --- a/client/iface/iface_windows.go +++ /dev/null @@ -1,41 +0,0 @@ -package iface - -import ( - "fmt" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{ - userspaceBind: true, - } - - if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) - return wgIFace, nil - } - - wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on non mobile") -} - -// GetInterfaceGUIDString returns an interface GUID. 
This is useful on Windows only -func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return w.tun.(*device.TunDevice).GetInterfaceGUIDString() -} diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go index cb6d7ccd9..f5ab29539 100644 --- a/client/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type IWGIface interface { @@ -22,6 +23,7 @@ type IWGIface interface { ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error + GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go index 6baeb66ae..96eec52a5 100644 --- a/client/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type IWGIface interface { @@ -20,6 +21,7 @@ type IWGIface interface { ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error + GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go new file mode 100644 index 000000000..e986d6d7b --- /dev/null +++ b/client/iface/wgproxy/bind/proxy.go @@ -0,0 +1,137 @@ +package bind + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/bind" +) + +type ProxyBind struct { + Bind *bind.ICEBind + + wgAddr *net.UDPAddr + wgEndpoint *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool +} + +// AddTurnConn adds a new connection to the bind. +// nbAddr is the NetBird address of the remote peer. SetEndpoint returns the fake address that will be used in the +// WireGuard configuration.
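+//
+// A rough usage sketch (iceBind, nbAddr and relayedConn are illustrative placeholders):
+//
+//	proxy := &ProxyBind{Bind: iceBind}
+//	if err := proxy.AddTurnConn(ctx, nbAddr, relayedConn); err != nil { ... }
+//	peerEndpoint := proxy.EndpointAddr() // fake 127.1.x.x address for the WireGuard peer config
+//	proxy.Work()                         // start forwarding relayed packets into the bind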
+func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { + addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) + if err != nil { + return err + } + + p.wgAddr = addr + p.wgEndpoint = addrToEndpoint(addr) + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + return err + +} +func (p *ProxyBind) EndpointAddr() *net.UDPAddr { + return p.wgAddr +} + +func (p *ProxyBind) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + // Start the proxy only once + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyBind) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() +} + +func (p *ProxyBind) CloseConn() error { + if p.cancel == nil { + return fmt.Errorf("proxy not started") + } + return p.close() +} + +func (p *ProxyBind) close() error { + p.closeMu.Lock() + defer p.closeMu.Unlock() + + if p.closed { + return nil + } + p.closed = true + + p.cancel() + + p.Bind.RemoveEndpoint(p.wgAddr) + + return p.remoteConn.Close() +} + +func (p *ProxyBind) proxyToLocal(ctx context.Context) { + defer func() { + if err := p.close(); err != nil { + log.Warnf("failed to close remote conn: %s", err) + } + }() + + buf := make([]byte, 1500) + for { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + msg := bind.RecvMessage{ + Endpoint: p.wgEndpoint, + Buffer: buf[:n], + } + p.Bind.RecvChan <- msg + p.pausedMu.Unlock() + } +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} +} diff --git a/client/internal/wgproxy/ebpf/portlookup.go b/client/iface/wgproxy/ebpf/portlookup.go similarity index 91% rename from client/internal/wgproxy/ebpf/portlookup.go rename to client/iface/wgproxy/ebpf/portlookup.go index 0e2c20c99..fce8f1507 100644 --- a/client/internal/wgproxy/ebpf/portlookup.go +++ b/client/iface/wgproxy/ebpf/portlookup.go @@ -5,9 +5,9 @@ import ( "net" ) -const ( +var ( portRangeStart = 3128 - portRangeEnd = 3228 + portRangeEnd = portRangeStart + 100 ) type portLookup struct { diff --git a/client/internal/wgproxy/ebpf/portlookup_test.go b/client/iface/wgproxy/ebpf/portlookup_test.go similarity index 92% rename from client/internal/wgproxy/ebpf/portlookup_test.go rename to client/iface/wgproxy/ebpf/portlookup_test.go index 92f4b8eee..a2e92fc79 100644 --- a/client/internal/wgproxy/ebpf/portlookup_test.go +++ b/client/iface/wgproxy/ebpf/portlookup_test.go @@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) { func Test_portLookup_on_allocated(t *testing.T) { pl := portLookup{} + portRangeStart = 4128 + portRangeEnd = portRangeStart + 100 + allocatedPort, err := allocatePort(portRangeStart) if err != nil { t.Fatal(err) diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go similarity index 99% rename from client/internal/wgproxy/ebpf/proxy.go rename to client/iface/wgproxy/ebpf/proxy.go index e850f4533..e21fc35d4 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -119,7 +119,7 @@ func (p *WGEBPFProxy) Free() error { p.ctxCancel() var result 
*multierror.Error - if p.conn != nil { // p.conn will be nil if we have failed to listen + if p.conn != nil { if err := p.conn.Close(); err != nil { result = multierror.Append(result, err) } diff --git a/client/internal/wgproxy/ebpf/proxy_test.go b/client/iface/wgproxy/ebpf/proxy_test.go similarity index 100% rename from client/internal/wgproxy/ebpf/proxy_test.go rename to client/iface/wgproxy/ebpf/proxy_test.go diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go similarity index 95% rename from client/internal/wgproxy/ebpf/wrapper.go rename to client/iface/wgproxy/ebpf/wrapper.go index b6a8ac452..efd5fd946 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -28,7 +28,7 @@ type ProxyWrapper struct { isStarted bool } -func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go new file mode 100644 index 000000000..3ad7dc59d --- /dev/null +++ b/client/iface/wgproxy/factory_kernel.go @@ -0,0 +1,49 @@ +//go:build linux && !android + +package wgproxy + +import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +type KernelFactory struct { + wgPort int + + ebpfProxy *ebpf.WGEBPFProxy +} + +func NewKernelFactory(wgPort int) *KernelFactory { + f := &KernelFactory{ + wgPort: wgPort, + } + + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) + if err := ebpfProxy.Listen(); err != nil { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") + log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) + return f + } + log.Infof("WireGuard Proxy Factory will produce eBPF proxy") + f.ebpfProxy = ebpfProxy + return f +} + +func (w *KernelFactory) GetProxy() Proxy { + if w.ebpfProxy == nil { + return udpProxy.NewWGUDPProxy(w.wgPort) + } + + return &ebpf.ProxyWrapper{ + WgeBPFProxy: w.ebpfProxy, + } +} + +func (w *KernelFactory) Free() error { + if w.ebpfProxy == nil { + return nil + } + return w.ebpfProxy.Free() +} diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go new file mode 100644 index 000000000..736944229 --- /dev/null +++ b/client/iface/wgproxy/factory_kernel_freebsd.go @@ -0,0 +1,29 @@ +package wgproxy + +import ( + log "github.com/sirupsen/logrus" + + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +// KernelFactory todo: check eBPF support on FreeBSD +type KernelFactory struct { + wgPort int +} + +func NewKernelFactory(wgPort int) *KernelFactory { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") + f := &KernelFactory{ + wgPort: wgPort, + } + + return f +} + +func (w *KernelFactory) GetProxy() Proxy { + return udpProxy.NewWGUDPProxy(w.wgPort) +} + +func (w *KernelFactory) Free() error { + return nil +} diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go new file mode 100644 index 000000000..e2d479331 --- /dev/null +++ b/client/iface/wgproxy/factory_usp.go @@ -0,0 +1,30 @@ +package wgproxy + +import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/bind" + proxyBind 
"github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +type USPFactory struct { + bind *bind.ICEBind +} + +func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { + log.Infof("WireGuard Proxy Factory will produce bind proxy") + f := &USPFactory{ + bind: iceBind, + } + return f +} + +func (w *USPFactory) GetProxy() Proxy { + return &proxyBind.ProxyBind{ + Bind: w.bind, + } +} + +func (w *USPFactory) Free() error { + return nil +} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go new file mode 100644 index 000000000..243aa2bd2 --- /dev/null +++ b/client/iface/wgproxy/proxy.go @@ -0,0 +1,15 @@ +package wgproxy + +import ( + "context" + "net" +) + +// Proxy is a transfer layer between the relayed connection and the WireGuard +type Proxy interface { + AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error + EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint + Work() // Work start or resume the proxy + Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + CloseConn() error +} diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go new file mode 100644 index 000000000..298c98cc0 --- /dev/null +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -0,0 +1,56 @@ +//go:build linux && !android + +package wgproxy + +import ( + "context" + "os" + "testing" + + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" +) + +func TestProxyCloseByRemoteConnEBPF(t *testing.T) { + if os.Getenv("GITHUB_ACTIONS") != "true" { + t.Skip("Skipping test as it requires root privileges") + } + ctx := context.Background() + + ebpfProxy := ebpf.NewWGEBPFProxy(51831) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %s", err) + } + + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %s", err) + } + }() + + tests := []struct { + name string + proxy Proxy + }{ + { + name: "ebpf proxy", + proxy: &ebpf.ProxyWrapper{ + WgeBPFProxy: ebpfProxy, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + relayedConn := newMockConn() + err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + if err != nil { + t.Errorf("error: %v", err) + } + + _ = relayedConn.Close() + if err := tt.proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} diff --git a/client/internal/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go similarity index 90% rename from client/internal/wgproxy/proxy_test.go rename to client/iface/wgproxy/proxy_test.go index b88ff3f83..64b617621 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" - "github.com/netbirdio/netbird/client/internal/wgproxy/usp" + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -84,7 +84,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }{ { name: "userspace proxy", - proxy: usp.NewWGUserSpaceProxy(51830), + proxy: udpProxy.NewWGUDPProxy(51830), }, } @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, nil, 
relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/iface/wgproxy/udp/proxy.go similarity index 73% rename from client/internal/wgproxy/usp/proxy.go rename to client/iface/wgproxy/udp/proxy.go index f73500717..200d961f3 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,19 +1,21 @@ -package usp +package udp import ( "context" + "errors" "fmt" + "io" "net" "sync" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" + cerrors "github.com/netbirdio/netbird/client/errors" ) -// WGUserSpaceProxy proxies -type WGUserSpaceProxy struct { +// WGUDPProxy proxies +type WGUDPProxy struct { localWGListenPort int remoteConn net.Conn @@ -28,10 +30,10 @@ type WGUserSpaceProxy struct { isStarted bool } -// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation -func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { +// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation +func NewWGUDPProxy(wgPort int) *WGUDPProxy { log.Debugf("Initializing new user space proxy with port %d", wgPort) - p := &WGUserSpaceProxy{ + p := &WGUDPProxy{ localWGListenPort: wgPort, } return p @@ -42,7 +44,7 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { // the connection is complete, an error is returned. Once successfully // connected, any expiration of the context will not affect the // connection. -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { +func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { dialer := net.Dialer{} localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { @@ -57,7 +59,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) return err } -func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { +func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr { if p.localConn == nil { return nil } @@ -66,7 +68,7 @@ func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { } // Work starts the proxy or resumes it if it was paused -func (p *WGUserSpaceProxy) Work() { +func (p *WGUDPProxy) Work() { if p.remoteConn == nil { return } @@ -83,7 +85,7 @@ func (p *WGUserSpaceProxy) Work() { } // Pause pauses the proxy from receiving data from the remote peer -func (p *WGUserSpaceProxy) Pause() { +func (p *WGUDPProxy) Pause() { if p.remoteConn == nil { return } @@ -94,14 +96,14 @@ func (p *WGUserSpaceProxy) Pause() { } // CloseConn close the localConn -func (p *WGUserSpaceProxy) CloseConn() error { +func (p *WGUDPProxy) CloseConn() error { if p.cancel == nil { return fmt.Errorf("proxy not started") } return p.close() } -func (p *WGUserSpaceProxy) close() error { +func (p *WGUDPProxy) close() error { p.closeMu.Lock() defer p.closeMu.Unlock() @@ -121,11 +123,11 @@ func (p *WGUserSpaceProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } - return errors.FormatErrorOrNil(result) + return cerrors.FormatErrorOrNil(result) } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { +func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -157,21 
+159,19 @@ func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { // proxyToLocal proxies from the Remote peer to local WireGuard // if the proxy is paused it will drain the remote conn and drop the packets -func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { +func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { - log.Warnf("error in proxy to local loop: %s", err) + if !errors.Is(err, io.EOF) { + log.Warnf("error in proxy to local loop: %s", err) + } } }() buf := make([]byte, 1500) for { - n, err := p.remoteConn.Read(buf) + n, err := p.remoteConnRead(ctx, buf) if err != nil { - if ctx.Err() != nil { - return - } - log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } @@ -193,3 +193,15 @@ func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { } } } + +func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) { + n, err = p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err) + return + } + return +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index ce2a12af1..5bb0905d2 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -3,6 +3,7 @@ package acl import ( "crypto/md5" "encoding/hex" + "errors" "fmt" "net" "net/netip" @@ -10,14 +11,18 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" ) +var ErrSourceRangesEmpty = errors.New("sources range is empty") + // Manager is a ACL rules manager type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap) @@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { - var newRouteRules = make(map[id.RuleID]struct{}) + newRouteRules := make(map[id.RuleID]struct{}, len(rules)) + var merr *multierror.Error + + // Apply new rules - firewall manager will return existing rule ID if already present for _, rule := range rules { id, err := d.applyRouteACL(rule) if err != nil { - return fmt.Errorf("apply route ACL: %w", err) + if errors.Is(err, ErrSourceRangesEmpty) { + log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) + } else { + merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) + } + continue } newRouteRules[id] = struct{}{} } + // Clean up old firewall rules for id := range d.routeRules { - if _, ok := newRouteRules[id]; !ok { + if _, exists := newRouteRules[id]; !exists { if err := d.firewall.DeleteRouteRule(id); err != nil { - log.Errorf("failed to delete route firewall rule: %v", err) - continue + merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err)) } - delete(d.routeRules, id) + // implicitly deleted from the map } } + d.routeRules = newRouteRules - return nil + return nberrors.FormatErrorOrNil(merr) } func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { if len(rule.SourceRanges) == 0 { - return "", fmt.Errorf("source ranges is empty") + return "", ErrSourceRangesEmpty } var 
sources []netip.Prefix diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 7d999669a..9a766021a 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,7 +1,6 @@ package acl import ( - "context" "net" "testing" @@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { - _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) @@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { - _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) diff --git a/client/internal/connect.go b/client/internal/connect.go index 74dc1f1b5..bcc9d17a3 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error { } // RunWithProbes runs the client's main logic with probes attached -func (c *ConnectClient) RunWithProbes( - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error { return c.run(MobileDependency{}, probes, runningChan) } @@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS( return c.run(mobileDependency, nil, nil) } -func (c *ConnectClient) run( - mobileDependency MobileDependency, - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error { defer func() { if r := recover(); r != nil { log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) @@ -117,12 +110,6 @@ func (c *ConnectClient) run( log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) - // Check if client was not shut down in a clean way and restore DNS config if required. - // Otherwise, we might not be able to connect to the management server to retrieve new config. 
- if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil { - log.Errorf("checking unclean shutdown error: %s", err) - } - backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -358,7 +345,11 @@ func (c *ConnectClient) Stop() error { if c.engine == nil { return nil } - return c.engine.Stop() + if err := c.engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } + + return nil } func (c *ConnectClient) isContextCancelled() bool { diff --git a/client/internal/dns/consts_freebsd.go b/client/internal/dns/consts_freebsd.go index 958eca8e5..64c8fe5eb 100644 --- a/client/internal/dns/consts_freebsd.go +++ b/client/internal/dns/consts_freebsd.go @@ -1,6 +1,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" ) diff --git a/client/internal/dns/consts_linux.go b/client/internal/dns/consts_linux.go index 32456a50f..15614b0c5 100644 --- a/client/internal/dns/consts_linux.go +++ b/client/internal/dns/consts_linux.go @@ -3,6 +3,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" ) diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index ae2c33b86..9a9218fa1 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -9,6 +9,8 @@ import ( "github.com/fsnotify/fsnotify" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) var ( @@ -20,7 +22,7 @@ var ( } ) -type repairConfFn func([]string, string, *resolvConf) error +type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error type repair struct { operationFile string @@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair { } } -func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) { +func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { if f.inotify != nil { return } @@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin log.Errorf("failed to rm inotify watch for resolv.conf: %s", err) } - err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf) + err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager) if err != nil { log.Errorf("failed to repair resolv.conf: %v", err) } diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index 4dba79e99..e948557b6 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/util" ) @@ -104,14 +105,14 @@ nameserver 8.8.8.8`, var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(operationFile, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + 
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) if err != nil { @@ -151,14 +152,14 @@ searchdomain netbird.cloud something` var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(tmpLink, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) if err != nil { diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 624e089cb..02ae26e10 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -11,6 +11,8 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -36,7 +38,7 @@ type fileConfigurator struct { nbNameserverIP string } -func newFileConfigurator() (hostManager, error) { +func newFileConfigurator() (*fileConfigurator, error) { fc := &fileConfigurator{} fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) return fc, nil @@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool { return false } -func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { backupFileExist := f.isBackupFileExist() if !config.RouteAll { if backupFileExist { @@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { f.repair.stopWatchFileChanges() - err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf) + err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) if err != nil { return err } - f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP) + f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager) return nil } -func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error { +func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) nameServers := generateNsList(nbNameserverIP, cfg) @@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. 
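For context on the repair flow that `watchFileChanges` drives: the watcher in file_repair_unix.go re-applies NetBird's DNS settings whenever another process rewrites resolv.conf. A reduced sketch of that pattern using fsnotify is below; note the real code also removes and re-adds the inotify watch around its own writes so it does not trigger on them.

```go
package repair

import (
	"log"

	"github.com/fsnotify/fsnotify"
)

// watchResolvConf re-runs onChange whenever another process rewrites the
// watched file, mirroring the repair flow in file_repair_unix.go. The
// callback and path are placeholders.
func watchResolvConf(path string, onChange func() error) error {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return err
	}
	if err := watcher.Add(path); err != nil {
		_ = watcher.Close()
		return err
	}
	go func() {
		defer watcher.Close()
		for event := range watcher.Events {
			if event.Op&fsnotify.Write == 0 {
				continue // only react to writes
			}
			if err := onChange(); err != nil {
				log.Printf("repair %s: %v", path, err)
			}
		}
	}()
	return nil
}
```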
Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) // create another backup for unclean shutdown detection right after overwriting the original resolv.conf - if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil { + if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil { log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) } @@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return os.RemoveAll(fileDefaultResolvConfBackupLocation) } @@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add return restoreResolvConfFile() } - log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring") + log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress) return nil } @@ -192,10 +190,6 @@ func restoreResolvConfFile() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err) - } - return nil } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index e55a07055..e2b5f699a 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -5,14 +5,14 @@ import ( "net/netip" "strings" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) type hostManager interface { - applyDNSConfig(config HostDNSConfig) error + applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error supportCustomPort() bool - restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error } type SystemDNSSettings struct { @@ -35,15 +35,15 @@ type DomainConfig struct { } type mockHostConfigurator struct { - applyDNSConfigFunc func(config HostDNSConfig) error + applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNSFunc func() error supportCustomPortFunc func() bool restoreUncleanShutdownDNSFunc func(*netip.Addr) error } -func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { if m.applyDNSConfigFunc != nil { - return m.applyDNSConfigFunc(config) + return m.applyDNSConfigFunc(config, stateManager) } return fmt.Errorf("method applyDNSSettings is not implemented") } @@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool { return false } -func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { - if m.restoreUncleanShutdownDNSFunc != nil { - return m.restoreUncleanShutdownDNSFunc(storedDNSAddress) - } - return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented") -} - func newNoopHostMocker() hostManager { return &mockHostConfigurator{ - applyDNSConfigFunc: func(config HostDNSConfig) error { return nil 
}, + applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil }, restoreHostDNSFunc: func() error { return nil }, supportCustomPortFunc: func() bool { return true }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index 9230cb257..5653710d7 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -1,15 +1,17 @@ package dns -import "net/netip" +import ( + "github.com/netbirdio/netbird/client/internal/statemanager" +) type androidHostManager struct { } -func newHostManager() (hostManager, error) { +func newHostManager() (*androidHostManager, error) { return &androidHostManager{}, nil } -func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { return nil } @@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error { func (a androidHostManager) supportCustomPort() bool { return false } - -func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 5dee305c2..b8ba33e34 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -8,12 +8,13 @@ import ( "fmt" "io" "net" - "net/netip" "os/exec" "strconv" "strings" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -37,7 +38,7 @@ type systemConfigurator struct { systemDNSSettings SystemDNSSettings } -func newHostManager() (hostManager, error) { +func newHostManager() (*systemConfigurator, error) { return &systemConfigurator{ createdKeys: make(map[string]struct{}), }, nil @@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool { return true } -func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error - // create a file for unclean shutdown detection - if err := createUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error { } } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) { return primaryService, router, nil } -func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (s *systemConfigurator) restoreUncleanShutdownDNS() error { if err := s.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via scutil: %w", err) } diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index ad8b14fb8..4a0acf572 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -3,9 +3,10 @@ package dns import ( "encoding/json" "fmt" - "net/netip" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) type iosHostManager struct { @@ -13,13 +14,13 @@ type iosHostManager struct { config HostDNSConfig } -func newHostManager(dnsManager 
IosDnsManager) (hostManager, error) { +func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { return &iosHostManager{ dnsManager: dnsManager, }, nil } -func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { return fmt.Errorf("marshal: %w", err) @@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error { func (a iosHostManager) supportCustomPort() bool { return false } - -func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 72b8f6c6e..7bd4aec64 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -4,9 +4,9 @@ package dns import ( "bufio" - "errors" "fmt" "io" + "net/netip" "os" "strings" @@ -21,27 +21,8 @@ const ( resolvConfManager ) -var ErrUnknownOsManagerType = errors.New("unknown os manager type") - type osManagerType int -func newOsManagerType(osManager string) (osManagerType, error) { - switch osManager { - case "netbird": - return fileManager, nil - case "file": - return netbirdManager, nil - case "networkManager": - return networkManager, nil - case "systemd": - return systemdManager, nil - case "resolvconf": - return resolvConfManager, nil - default: - return 0, ErrUnknownOsManagerType - } -} - func (t osManagerType) String() string { switch t { case netbirdManager: @@ -59,6 +40,11 @@ func (t osManagerType) String() string { } } +type restoreHostManager interface { + hostManager + restoreUncleanShutdownDNS(*netip.Addr) error +} + func newHostManager(wgInterface string) (hostManager, error) { osManager, err := getOSDNSManagerType() if err != nil { @@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) { return newHostManagerFromType(wgInterface, osManager) } -func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) { +func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) { switch osManager { case networkManager: return newNetworkManagerDbusConfigurator(wgInterface) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index c8bf2e552..7ecca8a41 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -3,11 +3,12 @@ package dns import ( "fmt" "io" - "net/netip" "strings" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -31,7 +32,7 @@ type registryConfigurator struct { routingAll bool } -func newHostManager(wgInterface WGIface) (hostManager, error) { +func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { guid, err := wgInterface.GetInterfaceGUIDString() if err != nil { return nil, err @@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) { return newHostManagerWithGuid(guid) } -func newHostManagerWithGuid(guid string) (hostManager, error) { +func newHostManagerWithGuid(guid string) (*registryConfigurator, error) { return ®istryConfigurator{ guid: guid, }, nil @@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool { return false } -func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error 
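The host_unix.go hunk just below narrows the cross-platform `hostManager` interface and layers the unix-only restore capability on top via interface embedding. Reduced to its essentials:

```go
package dns

import "net/netip"

// hostManager is the narrow, cross-platform surface (reduced here).
type hostManager interface {
	restoreHostDNS() error
	supportCustomPort() bool
}

// restoreHostManager embeds hostManager: every value that satisfies it is
// also usable wherever a plain hostManager is expected, which lets
// newHostManagerFromType return the richer interface on unix only.
type restoreHostManager interface {
	hostManager
	restoreUncleanShutdownDNS(*netip.Addr) error
}
```

This pairs with the other recurring change in these hunks: constructors now return concrete types (`*registryConfigurator`, `*systemConfigurator`, and so on) rather than the interface, following the "accept interfaces, return structs" convention.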
{ var err error if config.RouteAll { err = r.addDNSSetupForAll(config.ServerIP) @@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - // create a file for unclean shutdown detection - if err := createUncleanShutdownIndicator(r.guid); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error { return fmt.Errorf("remove interface registry key: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) { return regKey, nil } -func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (r *registryConfigurator) restoreUncleanShutdownDNS() error { if err := r.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via registry: %w", err) } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 184047a64..63bbead77 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -16,6 +16,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbversion "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{ type networkManagerDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() { } } -func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) { +func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) { obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) if err != nil { return nil, fmt.Errorf("get nm dbus: %w", err) @@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) return &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool { return false } -func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { connSettings, configVersion, err := n.getAppliedConnectionSettings() if err != nil { return fmt.Errorf("retrieving the applied connection settings, error: %w", err) @@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. 
- // The file content itself is not important for network-manager restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: networkManager, + WgIface: n.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("delete connection settings: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return nil } diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 0c17626c7..a5d1cc8a2 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -9,6 +9,8 @@ import ( "os/exec" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const resolvconfCommand = "resolvconf" @@ -22,7 +24,7 @@ type resolvconf struct { } // supported "openresolv" only -func newResolvConfConfigurator(wgInterface string) (hostManager, error) { +func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) { resolvConfEntries, err := parseDefaultResolvConf() if err != nil { log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err) @@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool { return false } -func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { +func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error if !config.RouteAll { err = r.restoreHostDNS() @@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { append([]string{config.ServerIP}, r.originalNameServers...), options) - // create a backup for unclean shutdown detection before the resolv.conf is changed - if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: resolvConfManager, + WgIface: r.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } err = r.applyConfig(buf) @@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error { cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) _, err := cmd.Output() if err != nil { - return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err) - } - - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) + return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err) } return nil @@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error { cmd.Stdin = &content _, err := cmd.Output() if err != nil { - return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err) + return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) } 
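On `applyConfig` in resolvconf_unix.go above: generated resolv.conf content is piped to the resolvconf CLI over stdin rather than written to disk. A standalone sketch of the same shape; the `-a` (add) flag is an assumption based on openresolv conventions, since only the delete path (`-f -d <iface>`) is visible in this hunk.

```go
package resolvconf

import (
	"bytes"
	"fmt"
	"os/exec"
)

// applyConfig hands generated resolv.conf content to the resolvconf CLI
// over stdin. The "-a" flag is an assumption; the diff only shows the
// delete path using "-f -d <iface>".
func applyConfig(content bytes.Buffer, ifaceName string) error {
	cmd := exec.Command("resolvconfCommand-placeholder", "-a", ifaceName) // real code uses the resolvconf binary
	cmd.Stdin = &content
	if _, err := cmd.Output(); err != nil {
		return fmt.Errorf("applying resolvconf configuration for %s interface: %w", ifaceName, err)
	}
	return nil
}
```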
return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index a4651ebb5..929e1e60c 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,6 +7,7 @@ import ( "runtime" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -14,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -63,6 +65,7 @@ type DefaultServer struct { iosDnsManager IosDnsManager statusRecorder *peer.Status + stateManager *statemanager.Manager } type handlerWithStop interface { @@ -77,12 +80,7 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) { var addrPort *netip.AddrPort if customAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(customAddress) @@ -99,7 +97,7 @@ func NewDefaultServer( dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems @@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream( statusRecorder *peer.Status, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -130,12 +128,12 @@ func NewDefaultServerIos( iosDnsManager IosDnsManager, statusRecorder *peer.Status, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.iosDnsManager = iosDnsManager return ds } -func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer { +func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ ctx: ctx, @@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi }, wgInterface: wgInterface, statusRecorder: statusRecorder, + stateManager: stateManager, hostsDNSHolder: newHostsDNSHolder(), } @@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) { } } + s.stateManager.RegisterState(&ShutdownState{}) s.hostManager, err = s.initialize() if err != nil { return fmt.Errorf("initialize: %w", err) @@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() { s.ctxCancel() if s.hostManager != nil { - err := s.hostManager.restoreHostDNS() - if err != nil { - log.Error(err) + if err := s.hostManager.restoreHostDNS(); err != nil { + log.Error("failed to restore host DNS settings: ", err) + } else if err := 
s.stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete shutdown dns state: %v", err) } } @@ -318,10 +319,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { hostUpdate.RouteAll = false } - if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { + if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { log.Error(err) } + // persist dns state right away + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + defer cancel() + if err := s.stateManager.PersistState(ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) } @@ -521,10 +529,17 @@ func (s *DefaultServer) upstreamCallbacks( } } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } + // persist dns state right away + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + defer cancel() + if err := s.stateManager.PersistState(ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() } @@ -551,7 +566,7 @@ func (s *DefaultServer) upstreamCallbacks( s.currentConfig.RouteAll = true s.service.RegisterMux(nbdns.RootZone, handler) } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 53d18a678..21f1f1b7d 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" @@ -267,7 +268,17 @@ func TestUpdateDNSServer(t *testing.T) { if err != nil { t.Fatal(err) } - wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) + + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun230%d", n), + Address: fmt.Sprintf("100.66.100.%d/32", n+1), + WGPort: 33100, + WGPrivKey: privKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + + wgIface, err := iface.NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -281,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Fatal(err) } @@ -345,7 +356,15 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { } privKey, _ := wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: "utun2301", + Address: "100.66.100.1/32", + WGPort: 33100, + 
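The "persist dns state right away" additions above bound the state write with a 3-second deadline derived from the server context, so a slow disk cannot stall a DNS update. A self-contained illustration of that deadline pattern, with `persist` standing in for `stateManager.PersistState`:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// persist stands in for stateManager.PersistState: a write that honors
// context cancellation instead of blocking the caller indefinitely.
func persist(ctx context.Context) error {
	select {
	case <-time.After(5 * time.Second): // simulate a slow state write
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	// Same 3-second bound as applyConfiguration uses in the diff.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	fmt.Println(persist(ctx)) // context deadline exceeded
}
```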
WGPrivKey: privKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgIface, err := iface.NewWGIFace(opts) if err != nil { t.Errorf("build interface wireguard: %v", err) return @@ -382,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -477,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil) if err != nil { t.Fatalf("%v", err) } @@ -536,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ + ctx: context.Background(), service: NewServiceViaMemory(&mocWGIface{}), localResolver: &localResolver{ registeredMap: make(registrationMap), @@ -552,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { } var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error { + hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { domains := []string{} for _, item := range config.Domains { if item.Disabled { @@ -803,7 +823,17 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { } privKey, _ := wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) + + opts := iface.WGIFaceOpts{ + IFaceName: "utun2301", + Address: "100.66.100.2/24", + WGPort: 33100, + WGPrivKey: privKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + + wgIface, err := iface.NewWGIFace(opts) if err != nil { t.Fatalf("build interface wireguard: %v", err) return nil, err diff --git a/client/internal/dns/server_windows.go b/client/internal/dns/server_windows.go index 5e1494e9e..bc051d59b 100644 --- a/client/internal/dns/server_windows.go +++ b/client/internal/dns/server_windows.go @@ -1,5 +1,5 @@ package dns -func (s *DefaultServer) initialize() (manager hostManager, err error) { +func (s *DefaultServer) initialize() (hostManager, error) { return newHostManager(s.wgInterface) } diff --git a/client/internal/dns/systemd_freebsd.go b/client/internal/dns/systemd_freebsd.go index 0de805337..41c8bf019 100644 --- a/client/internal/dns/systemd_freebsd.go +++ b/client/internal/dns/systemd_freebsd.go @@ -7,7 +7,7 @@ import ( var errNotImplemented = errors.New("not implemented") -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(string) (restoreHostManager, error) { return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index e2fa5b71a..a031be582 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -38,6 +39,7 @@ const 
( type systemdDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct { MatchOnly bool } -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) { iface, err := net.InterfaceByName(wgInterface) if err != nil { return nil, fmt.Errorf("get interface: %w", err) @@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { return &systemdDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { return true } -func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { parsedIP, err := netip.ParseAddr(config.ServerIP) if err != nil { return fmt.Errorf("unable to parse ip address, error: %w", err) @@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. - // The file content itself is not important for systemd restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: systemdManager, + WgIface: s.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. 
Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("unable to revert link configuration, got error: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return s.flushCaches() } diff --git a/client/internal/dns/unclean_shutdown_android.go b/client/internal/dns/unclean_shutdown_android.go deleted file mode 100644 index 105fb00bf..000000000 --- a/client/internal/dns/unclean_shutdown_android.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index e077ec84d..9bbdd2b56 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -3,57 +3,25 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - log "github.com/sirupsen/logrus" ) -const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns" +type ShutdownState struct { +} -func CheckUncleanShutdown(string) error { - if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation) +func (s *ShutdownState) Name() string { + return "dns_state" +} +func (s *ShutdownState) Cleanup() error { manager, err := newHostManager() if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func createUncleanShutdownIndicator() error { - dir := filepath.Dir(fileUncleanShutdownFileLocation) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err) - } - - return nil -} - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err) - } - return nil -} diff --git a/client/internal/dns/unclean_shutdown_ios.go b/client/internal/dns/unclean_shutdown_ios.go deleted file mode 100644 index 105fb00bf..000000000 --- a/client/internal/dns/unclean_shutdown_ios.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_mobile.go b/client/internal/dns/unclean_shutdown_mobile.go new file mode 100644 index 000000000..0d3a2cdbd --- /dev/null +++ b/client/internal/dns/unclean_shutdown_mobile.go @@ -0,0 +1,14 @@ +//go:build ios || android + +package dns + +type ShutdownState struct { +} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} diff --git 
a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index 8a32090c3..fcf60c694 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -3,66 +3,44 @@ package dns import ( - "errors" "fmt" - "io/fs" "net/netip" "os" "path/filepath" - "strings" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" ) -func CheckUncleanShutdown(wgIface string) error { - if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } +type ShutdownState struct { + ManagerType osManagerType + DNSAddress netip.Addr + WgIface string +} - log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation) +func (s *ShutdownState) Name() string { + return "dns_state" +} - managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation) - if err != nil { - return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - - managerFields := strings.Split(string(managerData), ",") - if len(managerFields) < 2 { - return errors.New("split manager data: insufficient number of fields") - } - osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1] - - dnsAddress, err := netip.ParseAddr(dnsAddressStr) - if err != nil { - return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err) - } - - log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr) - - // determine os manager type, so we can invoke the respective restore action - osManagerType, err := newOsManagerType(osManagerTypeStr) - if err != nil { - return fmt.Errorf("detect previous host manager: %w", err) - } - - manager, err := newHostManagerFromType(wgIface, osManagerType) +func (s *ShutdownState) Cleanup() error { + manager, err := newHostManagerFromType(s.WgIface, s.ManagerType) if err != nil { return fmt.Errorf("create previous host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } -func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error { +// TODO: move file contents to state manager +func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { + dnsAddress, err := netip.ParseAddr(dnsAddressStr) + if err != nil { + return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err) + } + dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { return fmt.Errorf("create dir %s: %w", dir, err) @@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType return fmt.Errorf("create %s: %w", sourcePath, err) } - managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress) - - if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - return nil -} - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownResolvConfLocation); 
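The per-platform `ShutdownState` types in these hunks all expose `Name` and `Cleanup`. How the state manager replays them after an unclean shutdown is not shown in this diff; continuing the statemanager sketch from earlier, the replay step presumably looks something like this (`PerformCleanup` and the error aggregation are assumptions):

```go
package statemanager

import (
	"errors"
	"fmt"
)

// State, as the dns ShutdownState types implement it: Name keys the
// persisted entry, Cleanup undoes whatever the unclean shutdown left behind.
type State interface {
	Name() string
	Cleanup() error
}

// PerformCleanup is assumed, not shown in the diff: run every persisted
// state's Cleanup and aggregate failures instead of stopping at the first.
func (m *Manager) PerformCleanup() error {
	if m == nil {
		return nil
	}
	var merr error
	for name, state := range m.states {
		if err := state.Cleanup(); err != nil {
			merr = errors.Join(merr, fmt.Errorf("cleanup %s: %w", name, err))
		}
	}
	return merr
}
```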
err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err) - } - if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err) + state := &ShutdownState{ + ManagerType: fileManager, + DNSAddress: dnsAddress, } + if err := stateManager.UpdateState(state); err != nil { + return fmt.Errorf("update state: %w", err) + } + return nil } diff --git a/client/internal/dns/unclean_shutdown_windows.go b/client/internal/dns/unclean_shutdown_windows.go index 41db46768..74e40cc11 100644 --- a/client/internal/dns/unclean_shutdown_windows.go +++ b/client/internal/dns/unclean_shutdown_windows.go @@ -1,75 +1,26 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - "github.com/sirupsen/logrus" ) -const ( - netbirdProgramDataLocation = "Netbird" - fileUncleanShutdownFile = "unclean_shutdown_dns.txt" -) +type ShutdownState struct { + Guid string +} -func CheckUncleanShutdown(string) error { - file := getUncleanShutdownFile() +func (s *ShutdownState) Name() string { + return "dns_state" +} - if _, err := os.Stat(file); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file) - - guid, err := os.ReadFile(file) - if err != nil { - return fmt.Errorf("read %s: %w", file, err) - } - - manager, err := newHostManagerWithGuid(string(guid)) +func (s *ShutdownState) Cleanup() error { + manager, err := newHostManagerWithGuid(s.Guid) if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func createUncleanShutdownIndicator(guid string) error { - file := getUncleanShutdownFile() - - dir := filepath.Dir(file) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(file, []byte(guid), 0600); err != nil { - return fmt.Errorf("create %s: %w", file, err) - } - - return nil -} - -func removeUncleanShutdownIndicator() error { - file := getUncleanShutdownFile() - - if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", file, err) - } - return nil -} - -func getUncleanShutdownFile() string { - return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile) -} diff --git a/client/internal/engine.go b/client/internal/engine.go index eac8ec098..2bf1c090c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,19 +23,21 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" - - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" 
"github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/client/internal/wgproxy" + "github.com/netbirdio/netbird/client/internal/statemanager" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -141,8 +143,7 @@ type Engine struct { ctx context.Context cancel context.CancelFunc - wgInterface iface.IWGIface - wgProxyFactory *wgproxy.Factory + wgInterface iface.IWGIface udpMux *bind.UniversalUDPMuxDefault @@ -168,6 +169,8 @@ type Engine struct { checks []*mgmProto.Checks relayManager *relayClient.Manager + stateManager *statemanager.Manager + srWatcher *guard.SRWatcher } // Peer is an instance of the Connection Peer @@ -215,7 +218,7 @@ func NewEngineWithProbes( probes *ProbeHolder, checks []*mgmProto.Checks, ) *Engine { - return &Engine{ + engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, signal: signalClient, @@ -234,6 +237,11 @@ func NewEngineWithProbes( probes: probes, checks: checks, } + if path := statemanager.GetDefaultStatePath(); path != "" { + engine.stateManager = statemanager.New(path) + } + + return engine } func (e *Engine) Stop() error { @@ -255,7 +263,11 @@ func (e *Engine) Stop() error { e.stopDNSServer() if e.routeManager != nil { - e.routeManager.Stop() + e.routeManager.Stop(e.stateManager) + } + + if e.srWatcher != nil { + e.srWatcher.Close() } err := e.removeAllPeers() @@ -277,6 +289,17 @@ func (e *Engine) Stop() error { e.close() log.Infof("stopped Netbird Engine") + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := e.stateManager.Stop(ctx); err != nil { + return fmt.Errorf("failed to stop state manager: %w", err) + } + if err := e.stateManager.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + return nil } @@ -299,9 +322,6 @@ func (e *Engine) Start() error { } e.wgInterface = wgIface - userspace := e.wgInterface.IsUserspaceBind() - e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort) - if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") if e.config.RosenpassPermissive { @@ -319,6 +339,8 @@ func (e *Engine) Start() error { } } + e.stateManager.Start() + initialRoutes, dnsServer, err := e.newDnsServer() if err != nil { e.close() @@ -327,7 +349,7 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() + beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { @@ -344,7 +366,7 @@ func (e *Engine) Start() error { return fmt.Errorf("create wg interface: %w", err) } - e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) if err != nil { log.Errorf("failed creating firewall manager: %s", err) } @@ -374,6 +396,18 @@ func (e *Engine) Start() error { 
return fmt.Errorf("initialize dns server: %w", err) } + iceCfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.UDPMuxDefault, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + + e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) + e.srWatcher.Start() + e.receiveSignalEvents() e.receiveManagementEvents() e.receiveProbeEvents() @@ -956,7 +990,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e LocalWgPort: e.config.WgPort, RosenpassPubKey: e.getRosenpassPubKey(), RosenpassAddr: e.getRosenpassAddr(), - ICEConfig: peer.ICEConfig{ + ICEConfig: icemaker.Config{ StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, @@ -966,7 +1000,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) if err != nil { return nil, err } @@ -1117,12 +1151,6 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { - if e.wgProxyFactory != nil { - if err := e.wgProxyFactory.Free(); err != nil { - log.Errorf("failed closing ebpf proxy: %s", err) - } - } - log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { @@ -1139,7 +1167,7 @@ func (e *Engine) close() { } if e.firewall != nil { - err := e.firewall.Reset() + err := e.firewall.Reset(e.stateManager) if err != nil { log.Warnf("failed to reset firewall: %s", err) } @@ -1167,21 +1195,29 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { log.Errorf("failed to create pion's stdnet: %s", err) } - var mArgs *device.MobileIFaceArguments + opts := iface.WGIFaceOpts{ + IFaceName: e.config.WgIfaceName, + Address: e.config.WgAddr, + WGPort: e.config.WgPort, + WGPrivKey: e.config.WgPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: transportNet, + FilterFn: e.addrViaRoutes, + } + switch runtime.GOOS { case "android": - mArgs = &device.MobileIFaceArguments{ + opts.MobileArgs = &device.MobileIFaceArguments{ TunAdapter: e.mobileDep.TunAdapter, TunFd: int(e.mobileDep.FileDescriptor), } case "ios": - mArgs = &device.MobileIFaceArguments{ + opts.MobileArgs = &device.MobileIFaceArguments{ TunFd: int(e.mobileDep.FileDescriptor), } - default: } - return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes) + return iface.NewWGIFace(opts) } func (e *Engine) wgInterfaceCreate() (err error) { @@ -1222,10 +1258,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) return nil, dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager) if err != nil { return nil, nil, err } + return nil, dnsServer, nil } } diff --git a/client/internal/engine_test.go 
b/client/internal/engine_test.go index 74b10ee44..0018af6df 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -29,6 +29,8 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { } engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.ctx = ctx + engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) type testCase struct { name string @@ -602,7 +605,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) + + opts := iface.WGIFaceOpts{ + IFaceName: wgIfaceName, + Address: wgAddr, + WGPort: engine.config.WgPort, + WGPrivKey: key.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + engine.wgInterface, err = iface.NewWGIFace(opts) assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 @@ -774,7 +786,15 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: wgIfaceName, + Address: wgAddr, + WGPort: 33100, + WGPrivKey: key.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + engine.wgInterface, err = iface.NewWGIFace(opts) assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 1b740388d..56b772759 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -10,15 +10,16 @@ import ( "sync" "time" - "github.com/cenkalti/backoff/v4" "github.com/pion/ice/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/wgproxy" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -32,8 +33,6 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 - - reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -63,7 +62,7 @@ type ConnConfig struct { RosenpassAddr string // ICEConfig ICE protocol configuration - ICEConfig ICEConfig + ICEConfig icemaker.Config } type WorkerCallbacks struct { @@ -81,11 +80,10 @@ type Conn struct { ctxCancel context.CancelFunc config ConnConfig statusRecorder *Status - wgProxyFactory *wgproxy.Factory signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager - allowedIPsIP string 
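The repeated `iface.WGIFaceOpts` migrations in the test files replace a long positional parameter list with an options struct. A reduced sketch of what the refactor buys; field names follow the diff, but the types and the defaulting below are simplified stand-ins for illustration:

```go
package iface

import "fmt"

// WGIFaceOpts mirrors the field names in the diff. Call sites name every
// argument, and optional inputs (FilterFn, MobileArgs in the real struct)
// can simply be omitted without touching other callers.
type WGIFaceOpts struct {
	IFaceName string
	Address   string
	WGPort    int
	WGPrivKey string
	MTU       int
}

// describe shows how zero values double as defaults in this pattern.
func describe(opts WGIFaceOpts) string {
	if opts.MTU == 0 {
		opts.MTU = 1280 // invented fallback, illustration only
	}
	return fmt.Sprintf("%s %s port=%d mtu=%d", opts.IFaceName, opts.Address, opts.WGPort, opts.MTU)
}
```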
+ allowedIP net.IP + allowedNet string handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) @@ -107,17 +105,13 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy - // for reconnection operations - iCEDisconnected chan bool - relayDisconnected chan bool - connMonitor *ConnMonitor - reconnectCh <-chan struct{} + guard *guard.Guard } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { - _, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps) +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { + allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) return nil, err @@ -132,26 +126,14 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu ctxCancel: ctxCancel, config: config, statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, signaler: signaler, - iFaceDiscover: iFaceDiscover, relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), + allowedIP: allowedIP, + allowedNet: allowedNet.String(), statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), - - iCEDisconnected: make(chan bool, 1), - relayDisconnected: make(chan bool, 1), } - conn.connMonitor, conn.reconnectCh = NewConnMonitor( - signaler, - iFaceDiscover, - config, - conn.relayDisconnected, - conn.iCEDisconnected, - ) - rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, OnDisconnected: conn.onWorkerRelayStateDisconnected, @@ -162,7 +144,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu OnStatusChanged: conn.onWorkerICEStateDisconnected, } - conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns) + ctrl := isController(config) + conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) @@ -177,6 +160,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } + conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher) + go conn.handshaker.Listen() return conn, nil @@ -187,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // be used. 
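On the `allowedIPsIP` split above: `net.ParseCIDR` returns both the host address and the containing network, and `NewConn` now keeps each for a different consumer, the IP for the proxy's UDP endpoint and the network string for `onConnected`. A quick standalone demonstration:

```go
package main

import (
	"fmt"
	"net"
)

func main() {
	// NewConn keeps both return values: the host IP feeds the proxy's
	// UDP address, the network string feeds the onConnected callback.
	ip, ipnet, err := net.ParseCIDR("100.66.100.2/24")
	if err != nil {
		panic(err)
	}
	fmt.Println(ip.String())    // 100.66.100.2
	fmt.Println(ipnet.String()) // 100.66.100.0/24
}
```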
func (conn *Conn) Open() { conn.log.Debugf("open connection to peer") + conn.mu.Lock() defer conn.mu.Unlock() conn.opened = true @@ -203,24 +189,19 @@ func (conn *Conn) Open() { conn.log.Warnf("error while updating the state err: %v", err) } - go conn.startHandshakeAndReconnect() + go conn.startHandshakeAndReconnect(conn.ctx) } -func (conn *Conn) startHandshakeAndReconnect() { - conn.waitInitialRandomSleepTime() +func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() if err != nil { conn.log.Errorf("failed to send initial offer: %v", err) } - go conn.connMonitor.Start(conn.ctx) - - if conn.workerRelay.IsController() { - conn.reconnectLoopWithRetry() - } else { - conn.reconnectLoopForOnDisconnectedEvent() - } + go conn.guard.Start(ctx) + go conn.listenGuardEvent(ctx) } // Close closes this peer Conn issuing a close event to the Conn closeCh @@ -319,104 +300,6 @@ func (conn *Conn) GetKey() string { return conn.config.Key } -func (conn *Conn) reconnectLoopWithRetry() { - // Give chance to the peer to establish the initial connection. - // With it, we can decrease to send necessary offer - select { - case <-conn.ctx.Done(): - return - case <-time.After(3 * time.Second): - } - - ticker := conn.prepareExponentTicker() - defer ticker.Stop() - time.Sleep(1 * time.Second) - - for { - select { - case t := <-ticker.C: - if t.IsZero() { - // in case if the ticker has been canceled by context then avoid the temporary loop - return - } - - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) - } - } else { - if conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) - } - } - - // checks if there is peer connection is established via relay or ice - if conn.isConnected() { - continue - } - - err := conn.handshaker.sendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - - case <-conn.reconnectCh: - ticker.Stop() - ticker = conn.prepareExponentTicker() - - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - } -} - -func (conn *Conn) prepareExponentTicker() *backoff.Ticker { - bo := backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.1, - Multiplier: 2, - MaxInterval: conn.config.Timeout, - MaxElapsedTime: reconnectMaxElapsedTime, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, conn.ctx) - - ticker := backoff.NewTicker(bo) - <-ticker.C // consume the initial tick what is happening right after the ticker has been created - - return ticker -} - -// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer -// when the connection is lost. It will try to establish a connection only once time if before the connection was established -// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not -// mean that to switch to it. We always force to use the higher priority connection. 
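The reconnect loop removed above reappears in the new guard package later in this diff; in both versions, retries are paced by a backoff.Ticker from github.com/cenkalti/backoff/v4. Once MaxElapsedTime is exceeded the ticker closes its channel, and a receive from a closed channel yields the zero time.Time, which is what the t.IsZero() check catches. A minimal, runnable sketch of that behavior (intervals shortened for illustration; not NetBird code):

```go
package main

import (
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v4"
)

func main() {
	// same shape as prepareExponentTicker, with illustrative intervals
	bo := &backoff.ExponentialBackOff{
		InitialInterval:     100 * time.Millisecond,
		RandomizationFactor: 0.1,
		Multiplier:          2,
		MaxInterval:         500 * time.Millisecond,
		MaxElapsedTime:      2 * time.Second, // the real loop uses 30 minutes
		Stop:                backoff.Stop,
		Clock:               backoff.SystemClock,
	}

	ticker := backoff.NewTicker(bo) // NewTicker resets the backoff and ticks immediately
	defer ticker.Stop()
	<-ticker.C // consume the initial tick that fires right away

	for t := range ticker.C {
		// each tick is where the loop would re-send an offer
		fmt.Println("tick at", t.Format(time.StampMilli))
	}
	// the channel is closed once MaxElapsedTime is exceeded; a bare receive
	// would then return the zero time.Time, hence the t.IsZero() check
	fmt.Println("retry window exhausted")
}
```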
-func (conn *Conn) reconnectLoopForOnDisconnectedEvent() { - for { - select { - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, try to send new offer") - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, try to send new offer") - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - - err := conn.handshaker.SendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - } -} - // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { conn.mu.Lock() @@ -516,7 +399,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - conn.notifyReconnectLoopICEDisconnected(changed) + conn.guard.SetICEConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -607,7 +490,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - conn.notifyReconnectLoopRelayDisconnected(changed) + conn.guard.SetRelayedConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -620,6 +503,20 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { } } +func (conn *Conn) listenGuardEvent(ctx context.Context) { + for { + select { + case <-conn.guard.Reconnect: + conn.log.Debugf("send offer to peer") + if err := conn.handshaker.SendOffer(); err != nil { + conn.log.Errorf("failed to send offer: %v", err) + } + case <-ctx.Done(): + return + } + } +} + func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { return conn.config.WgConfig.WgInterface.UpdatePeer( conn.config.WgConfig.RemoteKey, @@ -692,11 +589,11 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) } } -func (conn *Conn) waitInitialRandomSleepTime() { +func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { minWait := 100 maxWait := 800 duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond @@ -705,7 +602,7 @@ func (conn *Conn) waitInitialRandomSleepTime() { defer timeout.Stop() select { - case <-conn.ctx.Done(): + case <-ctx.Done(): case <-timeout.C: } } @@ -734,11 +631,17 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusDisconnected } -func (conn *Conn) isConnected() bool { +func (conn *Conn) isConnectedOnAllWay() (connected bool) { conn.mu.Lock() defer conn.mu.Unlock() - if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting { + defer func() { + if !connected { + conn.logTraceConnState() + } + }() + + if conn.statusICE.Get() == StatusDisconnected { return false } @@ -783,8 +686,13 @@ func (conn *Conn) freeUpConnID() { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.log.Debugf("setup proxied WireGuard connection") - wgProxy := conn.wgProxyFactory.GetProxy() - if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { + udpAddr := &net.UDPAddr{ + IP: conn.allowedIP, + Port: conn.config.WgConfig.WgListenPort, 
+ } + + wgProxy := conn.config.WgConfig.WgInterface.GetProxy() + if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) return nil, err } @@ -803,20 +711,6 @@ func (conn *Conn) removeWgPeer() error { return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) } -func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { - select { - case conn.relayDisconnected <- changed: - default: - } -} - -func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { - select { - case conn.iCEDisconnected <- changed: - default: - } -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { @@ -829,6 +723,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { } } +func (conn *Conn) logTraceConnState() { + if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { + conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) + } else { + conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE) + } +} + +func isController(config ConnConfig) bool { + return config.LocalKey > config.Key +} + func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go deleted file mode 100644 index 75722c990..000000000 --- a/client/internal/peer/conn_monitor.go +++ /dev/null @@ -1,212 +0,0 @@ -package peer - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/pion/ice/v3" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/stdnet" -) - -const ( - signalerMonitorPeriod = 5 * time.Second - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second -) - -type ConnMonitor struct { - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - config ConnConfig - relayDisconnected chan bool - iCEDisconnected chan bool - reconnectCh chan struct{} - currentCandidates []ice.Candidate - candidatesMu sync.Mutex -} - -func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { - reconnectCh := make(chan struct{}, 1) - cm := &ConnMonitor{ - signaler: signaler, - iFaceDiscover: iFaceDiscover, - config: config, - relayDisconnected: relayDisconnected, - iCEDisconnected: iCEDisconnected, - reconnectCh: reconnectCh, - } - return cm, reconnectCh -} - -func (cm *ConnMonitor) Start(ctx context.Context) { - signalerReady := make(chan struct{}, 1) - go cm.monitorSignalerReady(ctx, signalerReady) - - localCandidatesChanged := make(chan struct{}, 1) - go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) - - for { - select { - case changed := <-cm.relayDisconnected: - if !changed { - continue - } - log.Debugf("Relay state changed, triggering reconnect") - cm.triggerReconnect() - - case changed := <-cm.iCEDisconnected: - if !changed { - continue - } - log.Debugf("ICE state changed, triggering reconnect") - cm.triggerReconnect() - - case <-signalerReady: - log.Debugf("Signaler became ready, triggering reconnect") - cm.triggerReconnect() - - case <-localCandidatesChanged: - log.Debugf("Local candidates changed, triggering reconnect") - cm.triggerReconnect() - - case 
<-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { - if cm.signaler == nil { - return - } - - ticker := time.NewTicker(signalerMonitorPeriod) - defer ticker.Stop() - - lastReady := true - for { - select { - case <-ticker.C: - currentReady := cm.signaler.Ready() - if !lastReady && currentReady { - select { - case signalerReady <- struct{}{}: - default: - } - } - lastReady = currentReady - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { - ufrag, pwd, err := generateICECredentials() - if err != nil { - log.Warnf("Failed to generate ICE credentials: %v", err) - return - } - - ticker := time.NewTicker(candidatesMonitorPeriod) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { - log.Warnf("Failed to handle candidate tick: %v", err) - } - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { - log.Debugf("Gathering ICE candidates") - - transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) - if err != nil { - log.Errorf("failed to create pion's stdnet: %s", err) - } - - agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) - if err != nil { - return fmt.Errorf("create ICE agent: %w", err) - } - defer func() { - if err := agent.Close(); err != nil { - log.Warnf("Failed to close ICE agent: %v", err) - } - }() - - gatherDone := make(chan struct{}) - err = agent.OnCandidate(func(c ice.Candidate) { - log.Tracef("Got candidate: %v", c) - if c == nil { - close(gatherDone) - } - }) - if err != nil { - return fmt.Errorf("set ICE candidate handler: %w", err) - } - - if err := agent.GatherCandidates(); err != nil { - return fmt.Errorf("gather ICE candidates: %w", err) - } - - ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) - defer cancel() - - select { - case <-ctx.Done(): - return fmt.Errorf("wait for gathering: %w", ctx.Err()) - case <-gatherDone: - } - - candidates, err := agent.GetLocalCandidates() - if err != nil { - return fmt.Errorf("get local candidates: %w", err) - } - log.Tracef("Got candidates: %v", candidates) - - if changed := cm.updateCandidates(candidates); changed { - select { - case localCandidatesChanged <- struct{}{}: - default: - } - } - - return nil -} - -func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { - cm.candidatesMu.Lock() - defer cm.candidatesMu.Unlock() - - if len(cm.currentCandidates) != len(newCandidates) { - cm.currentCandidates = newCandidates - return true - } - - for i, candidate := range cm.currentCandidates { - if candidate.Address() != newCandidates[i].Address() { - cm.currentCandidates = newCandidates - return true - } - } - - return false -} - -func (cm *ConnMonitor) triggerReconnect() { - select { - case cm.reconnectCh <- struct{}{}: - default: - } -} diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index b4926a9d2..039952588 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -10,8 +10,9 @@ import ( "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/peer/guard" + 
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/util" ) @@ -20,7 +21,7 @@ var connConf = ConnConfig{ LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Timeout: time.Second, LocalWgPort: 51820, - ICEConfig: ICEConfig{ + ICEConfig: ice.Config{ InterfaceBlackList: nil, }, } @@ -44,11 +45,8 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) if err != nil { return } @@ -59,11 +57,8 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } @@ -96,11 +91,8 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } @@ -132,11 +124,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go new file mode 100644 index 000000000..bf3527a62 --- /dev/null +++ b/client/internal/peer/guard/guard.go @@ -0,0 +1,194 @@ +package guard + +import ( + "context" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" +) + +const ( + reconnectMaxElapsedTime = 30 * time.Minute +) + +type isConnectedFunc func() bool + +// Guard is responsible for the reconnection logic. +// It will trigger to send an offer to the peer then has connection issues. 
+// Guard is responsible for the reconnection logic.
+// It will trigger sending an offer to the peer when the connection has issues.
+// It watches these events:
+// - Relay client reconnected to home server
+// - Signal server connection state changed
+// - ICE connection disconnected
+// - Relayed connection disconnected
+// - ICE candidate changes
+type Guard struct {
+	Reconnect               chan struct{}
+	log                     *log.Entry
+	isController            bool
+	isConnectedOnAllWay     isConnectedFunc
+	timeout                 time.Duration
+	srWatcher               *SRWatcher
+	relayedConnDisconnected chan bool
+	iCEConnDisconnected     chan bool
+}
+
+func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
+	return &Guard{
+		Reconnect:               make(chan struct{}, 1),
+		log:                     log,
+		isController:            isController,
+		isConnectedOnAllWay:     isConnectedFn,
+		timeout:                 timeout,
+		srWatcher:               srWatcher,
+		relayedConnDisconnected: make(chan bool, 1),
+		iCEConnDisconnected:     make(chan bool, 1),
+	}
+}
+
+func (g *Guard) Start(ctx context.Context) {
+	if g.isController {
+		g.reconnectLoopWithRetry(ctx)
+	} else {
+		g.listenForDisconnectEvents(ctx)
+	}
+}
+
+func (g *Guard) SetRelayedConnDisconnected(changed bool) {
+	select {
+	case g.relayedConnDisconnected <- changed:
+	default:
+	}
+}
+
+func (g *Guard) SetICEConnDisconnected(changed bool) {
+	select {
+	case g.iCEConnDisconnected <- changed:
+	default:
+	}
+}
+
+// reconnectLoopWithRetry periodically checks (for at most 30 min) the connection status.
+// It tries to send an offer while the P2P connection is not established, or while the Relay connection is not established if Relay is supported.
+func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
+	waitForInitialConnectionTry(ctx)
+
+	srReconnectedChan := g.srWatcher.NewListener()
+	defer g.srWatcher.RemoveListener(srReconnectedChan)
+
+	ticker := g.prepareExponentTicker(ctx)
+	defer ticker.Stop()
+
+	tickerChannel := ticker.C
+
+	g.log.Infof("start reconnect loop...")
+	for {
+		select {
+		case t := <-tickerChannel:
+			if t.IsZero() {
+				g.log.Infof("retry timed out, stop periodic offer sending")
+				// after the backoff timeout, ticker.C will be closed; swap in a dummy channel to avoid a busy loop
+				tickerChannel = make(<-chan time.Time)
+				continue
+			}
+
+			if !g.isConnectedOnAllWay() {
+				g.triggerOfferSending()
+			}
+
+		case changed := <-g.relayedConnDisconnected:
+			if !changed {
+				continue
+			}
+			g.log.Debugf("Relay connection changed, reset reconnection ticker")
+			ticker.Stop()
+			ticker = g.prepareExponentTicker(ctx)
+			tickerChannel = ticker.C
+
+		case changed := <-g.iCEConnDisconnected:
+			if !changed {
+				continue
+			}
+			g.log.Debugf("ICE connection changed, reset reconnection ticker")
+			ticker.Stop()
+			ticker = g.prepareExponentTicker(ctx)
+			tickerChannel = ticker.C
+
+		case <-srReconnectedChan:
+			g.log.Debugf("has network changes, reset reconnection ticker")
+			ticker.Stop()
+			ticker = g.prepareExponentTicker(ctx)
+			tickerChannel = ticker.C
+
+		case <-ctx.Done():
+			g.log.Debugf("context is done, stop reconnect loop")
+			return
+		}
+	}
+}
+
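The setters above never block the worker callbacks that invoke them: both channels have capacity 1 and the send sits in a select with a default branch, so repeated state changes coalesce into one pending signal. The same pattern backs triggerOfferSending and the Reconnect channel. In isolation the pattern looks like this (a sketch, not NetBird code):

```go
package main

import "fmt"

func main() {
	// capacity 1, like Guard.relayedConnDisconnected and Guard.iCEConnDisconnected
	events := make(chan bool, 1)

	notify := func(changed bool) {
		select {
		case events <- changed: // delivered if the consumer is keeping up
		default: // otherwise dropped; one pending signal is already enough
		}
	}

	notify(true)
	notify(true) // coalesced with the still-pending signal above
	fmt.Println("received:", <-events, "still pending:", len(events))
}
```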
+// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
+// when the connection is lost. It will try to establish a connection only once if a connection was established before.
+// It tracks the ICE and Relay connection status separately. Just because a lower priority connection is reestablished, it does not
+// mean that we switch to it. We always force the use of the higher priority connection.
+func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
+	srReconnectedChan := g.srWatcher.NewListener()
+	defer g.srWatcher.RemoveListener(srReconnectedChan)
+
+	g.log.Infof("start listen for reconnect events...")
+	for {
+		select {
+		case changed := <-g.relayedConnDisconnected:
+			if !changed {
+				continue
+			}
+			g.log.Debugf("Relay connection changed, triggering reconnect")
+			g.triggerOfferSending()
+		case changed := <-g.iCEConnDisconnected:
+			if !changed {
+				continue
+			}
+			g.log.Debugf("ICE state changed, try to send new offer")
+			g.triggerOfferSending()
+		case <-srReconnectedChan:
+			g.triggerOfferSending()
+		case <-ctx.Done():
+			g.log.Debugf("context is done, stop reconnect loop")
+			return
+		}
+	}
+}
+
+func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
+	bo := backoff.WithContext(&backoff.ExponentialBackOff{
+		InitialInterval:     800 * time.Millisecond,
+		RandomizationFactor: 0.1,
+		Multiplier:          2,
+		MaxInterval:         g.timeout,
+		MaxElapsedTime:      reconnectMaxElapsedTime,
+		Stop:                backoff.Stop,
+		Clock:               backoff.SystemClock,
+	}, ctx)
+
+	ticker := backoff.NewTicker(bo)
+	<-ticker.C // consume the initial tick that fires right after the ticker has been created
+
+	return ticker
+}
+
+func (g *Guard) triggerOfferSending() {
+	select {
+	case g.Reconnect <- struct{}{}:
+	default:
+	}
+}
+
+// waitForInitialConnectionTry gives the peer a chance to establish the initial connection.
+// With it, we can avoid sending unnecessary offers.
+func waitForInitialConnectionTry(ctx context.Context) {
+	select {
+	case <-ctx.Done():
+		return
+	case <-time.After(3 * time.Second):
+	}
+}
diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go
new file mode 100644
index 000000000..b9c9aa134
--- /dev/null
+++ b/client/internal/peer/guard/ice_monitor.go
@@ -0,0 +1,135 @@
+package guard
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/pion/ice/v3"
+	log "github.com/sirupsen/logrus"
+
+	icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
+	"github.com/netbirdio/netbird/client/internal/stdnet"
+)
+
+const (
+	candidatesMonitorPeriod   = 5 * time.Minute
+	candidateGatheringTimeout = 5 * time.Second
+)
+
+type ICEMonitor struct {
+	ReconnectCh chan struct{}
+
+	iFaceDiscover stdnet.ExternalIFaceDiscover
+	iceConfig     icemaker.Config
+
+	currentCandidates []ice.Candidate
+	candidatesMu      sync.Mutex
+}
+
+func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
+	cm := &ICEMonitor{
+		ReconnectCh:   make(chan struct{}, 1),
+		iFaceDiscover: iFaceDiscover,
+		iceConfig:     config,
+	}
+	return cm
+}
+
+func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
+	ufrag, pwd, err := icemaker.GenerateICECredentials()
+	if err != nil {
+		log.Warnf("Failed to generate ICE credentials: %v", err)
+		return
+	}
+
+	ticker := time.NewTicker(candidatesMonitorPeriod)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			changed, err := cm.handleCandidateTick(ctx, ufrag, pwd)
+			if err != nil {
+				log.Warnf("Failed to check ICE changes: %v", err)
+				continue
+			}
+
+			if changed {
+				onChanged()
+			}
+		case <-ctx.Done():
+			return
+		}
+	}
+}
+
+func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
+	log.Debugf("Gathering ICE candidates")
+
+	agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
+	if err != nil {
+		return false, fmt.Errorf("create ICE agent: %w", err)
+	}
+	defer func() {
+		if err := agent.Close(); err != nil {
+			log.Warnf("Failed to close ICE agent: %v", err)
+		}
+	}()
+
+	gatherDone := make(chan struct{})
+	err = agent.OnCandidate(func(c ice.Candidate) {
+		log.Tracef("Got candidate: %v", c)
+		if c == nil {
+			close(gatherDone)
+		}
+	})
+	if err != nil {
+		return false, fmt.Errorf("set ICE candidate handler: %w", err)
+	}
+
+	if err := agent.GatherCandidates(); err != nil {
+		return false, fmt.Errorf("gather ICE candidates: %w", err)
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
+	defer cancel()
+
+	select {
+	case <-ctx.Done():
+		return false, fmt.Errorf("wait for gathering timed out")
+	case <-gatherDone:
+	}
+
+	candidates, err := agent.GetLocalCandidates()
+	if err != nil {
+		return false, fmt.Errorf("get local candidates: %w", err)
+	}
+	log.Tracef("Got candidates: %v", candidates)
+
+	return cm.updateCandidates(candidates), nil
+}
+
+func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
+	cm.candidatesMu.Lock()
+	defer cm.candidatesMu.Unlock()
+
+	if len(cm.currentCandidates) != len(newCandidates) {
+		cm.currentCandidates = newCandidates
+		return true
+	}
+
+	for i, candidate := range cm.currentCandidates {
+		if candidate.Address() != newCandidates[i].Address() {
+			cm.currentCandidates = newCandidates
+			return true
+		}
+	}
+
+	return false
+}
+
+func candidateTypesP2P() []ice.CandidateType {
+	return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
+}
diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go
new file mode 100644
index 000000000..90e45426f
--- /dev/null
+++ b/client/internal/peer/guard/sr_watcher.go
@@ -0,0 +1,119 @@
+package guard
+
+import (
+	"context"
+	"sync"
+
+	log "github.com/sirupsen/logrus"
+
+	"github.com/netbirdio/netbird/client/internal/peer/ice"
+	"github.com/netbirdio/netbird/client/internal/stdnet"
+)
+
+type chNotifier interface {
+	SetOnReconnectedListener(func())
+	Ready() bool
+}
+
+type SRWatcher struct {
+	signalClient chNotifier
+	relayManager chNotifier
+
+	listeners     map[chan struct{}]struct{}
+	mu            sync.Mutex
+	iFaceDiscover stdnet.ExternalIFaceDiscover
+	iceConfig     ice.Config
+
+	cancelIceMonitor context.CancelFunc
+}
+
+// NewSRWatcher creates a new SRWatcher. This watcher notifies its listeners when the ICE candidates change, the
+// Relay client reconnects to the home server, or the Signal client reconnects.
+func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher { + srw := &SRWatcher{ + signalClient: signalClient, + relayManager: relayManager, + iFaceDiscover: iFaceDiscover, + iceConfig: iceConfig, + listeners: make(map[chan struct{}]struct{}), + } + return srw +} + +func (w *SRWatcher) Start() { + w.mu.Lock() + defer w.mu.Unlock() + + if w.cancelIceMonitor != nil { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + w.cancelIceMonitor = cancel + + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + go iceMonitor.Start(ctx, w.onICEChanged) + w.signalClient.SetOnReconnectedListener(w.onReconnected) + w.relayManager.SetOnReconnectedListener(w.onReconnected) + +} + +func (w *SRWatcher) Close() { + w.mu.Lock() + defer w.mu.Unlock() + + if w.cancelIceMonitor == nil { + return + } + w.cancelIceMonitor() + w.signalClient.SetOnReconnectedListener(nil) + w.relayManager.SetOnReconnectedListener(nil) +} + +func (w *SRWatcher) NewListener() chan struct{} { + w.mu.Lock() + defer w.mu.Unlock() + + listenerChan := make(chan struct{}, 1) + w.listeners[listenerChan] = struct{}{} + return listenerChan +} + +func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.listeners, listenerChan) + close(listenerChan) +} + +func (w *SRWatcher) onICEChanged() { + if !w.signalClient.Ready() { + return + } + + log.Infof("network changes detected by ICE agent") + w.notify() +} + +func (w *SRWatcher) onReconnected() { + if !w.signalClient.Ready() { + return + } + if !w.relayManager.Ready() { + return + } + + log.Infof("reconnected to Signal or Relay server") + w.notify() +} + +func (w *SRWatcher) notify() { + w.mu.Lock() + defer w.mu.Unlock() + for listener := range w.listeners { + select { + case listener <- struct{}{}: + default: + } + } +} diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go new file mode 100644 index 000000000..dc4750f24 --- /dev/null +++ b/client/internal/peer/ice/agent.go @@ -0,0 +1,87 @@ +package ice + +import ( + "time" + + "github.com/pion/ice/v3" + "github.com/pion/randutil" + "github.com/pion/stun/v2" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + lenUFrag = 16 + lenPwd = 32 + runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + iceKeepAliveDefault = 4 * time.Second + iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second +) + +var ( + failedTimeout = 6 * time.Second +) + +func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.InterfaceBlackList), + UDPMux: config.UDPMux, + UDPMuxSrflx: config.UDPMuxSrflx, + 
NAT1To1IPs: config.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + +func GenerateICECredentials() (string, string, error) { + ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) + if err != nil { + return "", "", err + } + + pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) + if err != nil { + return "", "", err + } + return ufrag, pwd, nil +} + +func CandidateTypes() []ice.CandidateType { + if hasICEForceRelayConn() { + return []ice.CandidateType{ice.CandidateTypeRelay} + } + + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} +} + +func CandidateTypesP2P() []ice.CandidateType { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go new file mode 100644 index 000000000..8abc842f0 --- /dev/null +++ b/client/internal/peer/ice/config.go @@ -0,0 +1,22 @@ +package ice + +import ( + "sync/atomic" + + "github.com/pion/ice/v3" +) + +type Config struct { + // StunTurn is a list of STUN and TURN URLs + StunTurn *atomic.Value // []*stun.URI + + // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering + // (e.g. if eth0 is in the list, host candidate of this interface won't be used) + InterfaceBlackList []string + DisableIPv6Discovery bool + + UDPMux ice.UDPMux + UDPMuxSrflx ice.UniversalUDPMux + + NATExternalIPs []string +} diff --git a/client/internal/peer/env_config.go b/client/internal/peer/ice/env.go similarity index 80% rename from client/internal/peer/env_config.go rename to client/internal/peer/ice/env.go index 87b626df7..3b0cb74ad 100644 --- a/client/internal/peer/env_config.go +++ b/client/internal/peer/ice/env.go @@ -1,4 +1,4 @@ -package peer +package ice import ( "os" @@ -10,12 +10,19 @@ import ( ) const ( + envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" - envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" + + msgWarnInvalidValue = "invalid value %s set for %s, using default %v" ) +func hasICEForceRelayConn() bool { + disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) + return strings.ToLower(disconnectedTimeoutEnv) == "true" +} + func iceKeepAlive() time.Duration { keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec) if keepAliveEnv == "" { @@ -25,7 +32,7 @@ func iceKeepAlive() time.Duration { log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) + log.Warnf(msgWarnInvalidValue, keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) return iceKeepAliveDefault } @@ -41,7 +48,7 @@ func iceDisconnectedTimeout() time.Duration { log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) disconnectedTimeoutSec, err := 
strconv.Atoi(disconnectedTimeoutEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) + log.Warnf(msgWarnInvalidValue, disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) return iceDisconnectedTimeoutDefault } @@ -57,14 +64,9 @@ func iceRelayAcceptanceMinWait() time.Duration { log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) + log.Warnf(msgWarnInvalidValue, iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) return iceRelayAcceptanceMinWaitDefault } return time.Duration(disconnectedTimeoutSec) * time.Second } - -func hasICEForceRelayConn() bool { - disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) - return strings.ToLower(disconnectedTimeoutEnv) == "true" -} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/ice/stdnet.go similarity index 94% rename from client/internal/peer/stdnet.go rename to client/internal/peer/ice/stdnet.go index 96d211dbc..3ce83727e 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -1,6 +1,6 @@ //go:build !android -package peer +package ice import ( "github.com/netbirdio/netbird/client/internal/stdnet" diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/ice/stdnet_android.go similarity index 94% rename from client/internal/peer/stdnet_android.go rename to client/internal/peer/ice/stdnet_android.go index a39a03b1c..84c665e6f 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/ice/stdnet_android.go @@ -1,4 +1,4 @@ -package peer +package ice import "github.com/netbirdio/netbird/client/internal/stdnet" diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index c86c1858f..55894218d 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -5,52 +5,20 @@ import ( "fmt" "net" "net/netip" - "runtime" "sync" - "sync/atomic" "time" "github.com/pion/ice/v3" - "github.com/pion/randutil" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) -const ( - iceKeepAliveDefault = 4 * time.Second - iceDisconnectedTimeoutDefault = 6 * time.Second - // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package - iceRelayAcceptanceMinWaitDefault = 2 * time.Second - - lenUFrag = 16 - lenPwd = 32 - runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -) - -var ( - failedTimeout = 6 * time.Second -) - -type ICEConfig struct { - // StunTurn is a list of STUN and TURN URLs - StunTurn *atomic.Value // []*stun.URI - - // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering - // (e.g. 
if eth0 is in the list, host candidate of this interface won't be used) - InterfaceBlackList []string - DisableIPv6Discovery bool - - UDPMux ice.UDPMux - UDPMuxSrflx ice.UniversalUDPMux - - NATExternalIPs []string -} - type ICEConnInfo struct { RemoteConn net.Conn RosenpassPubKey []byte @@ -103,7 +71,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal conn: callBacks, } - localUfrag, localPwd, err := generateICECredentials() + localUfrag, localPwd, err := icemaker.GenerateICECredentials() if err != nil { return nil, err } @@ -125,10 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { w.selectedPriority = connPriorityICEP2P - preferredCandidateTypes = candidateTypesP2P() + preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { w.selectedPriority = connPriorityICETurn - preferredCandidateTypes = candidateTypes() + preferredCandidateTypes = icemaker.CandidateTypes() } w.log.Debugf("recreate ICE agent") @@ -232,15 +200,10 @@ func (w *WorkerICE) Close() { } } -func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) - if err != nil { - w.log.Errorf("failed to create pion's stdnet: %s", err) - } - +func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) { w.sentExtraSrflx = false - agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } @@ -365,36 +328,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } -func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: candidateTypes, - InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), - UDPMux: config.ICEConfig.UDPMux, - UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: ufrag, - LocalPwd: pwd, - } - - if config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - - return ice.NewAgent(agentConfig) -} - func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ @@ -435,21 +368,6 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool return false } -func candidateTypes() []ice.CandidateType { - if hasICEForceRelayConn() { - return 
[]ice.CandidateType{ice.CandidateTypeRelay} - } - // TODO: remove this once we have refactored userspace proxy into the bind package - if runtime.GOOS == "ios" { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} - } - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} -} - -func candidateTypesP2P() []ice.CandidateType { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} -} - func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } @@ -460,16 +378,3 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } - -func generateICECredentials() (string, string, error) { - ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) - if err != nil { - return "", "", err - } - - pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) - if err != nil { - return "", "", err - } - return ufrag, pwd, nil -} diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index c02fccebc..c22dcdeda 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -31,6 +31,7 @@ type WorkerRelayCallbacks struct { type WorkerRelay struct { log *log.Entry + isController bool config ConnConfig relayManager relayClient.ManagerService callBacks WorkerRelayCallbacks @@ -44,9 +45,10 @@ type WorkerRelay struct { relaySupportedOnRemotePeer atomic.Bool } -func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { +func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { r := &WorkerRelay{ log: log, + isController: ctrl, config: config, relayManager: relayManager, callBacks: callbacks, @@ -80,6 +82,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Errorf("failed to open connection via Relay: %s", err) return } + w.relayLock.Lock() w.relayedConn = relayedConn w.relayLock.Unlock() @@ -136,10 +139,6 @@ func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool { return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally() } -func (w *WorkerRelay) IsController() bool { - return w.config.LocalKey > w.config.Key -} - func (w *WorkerRelay) RelayIsSupportedLocally() bool { return w.relayManager.HasRelayAddress() } @@ -212,7 +211,7 @@ func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { } func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string { - if w.IsController() { + if w.isController { return myRelayAddress } return remoteRelayAddress diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d7ddf7ae8..0a1c7dc56 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -31,14 +32,14 @@ import ( // Manager is a route manager interface type Manager interface { - Init() (nbnet.AddHookFunc, 
nbnet.RemoveHookFunc, error) + Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error - Stop() + Stop(stateManager *statemanager.Manager) } // DefaultManager is the default instance of a route manager @@ -120,12 +121,12 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(nil); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -136,7 +137,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } @@ -154,7 +155,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop stops the manager watchers and clean firewall rules -func (m *DefaultManager) Stop() { +func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() if m.serverRouter != nil { m.serverRouter.cleanUp() @@ -172,7 +173,7 @@ func (m *DefaultManager) Stop() { } if !nbnet.CustomRoutingDisabled() { - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(stateManager); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2f26f7a5e..e669bc44a 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -407,7 +407,15 @@ func TestManagerUpdateRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun43%d", n), + Address: "100.65.65.2/24", + WGPort: 33100, + WGPrivKey: peerPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -418,10 +426,10 @@ func TestManagerUpdateRoutes(t *testing.T) { ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) - _, _, err = routeManager.Init() + _, _, err = routeManager.Init(nil) require.NoError(t, err, "should init route manager") - defer routeManager.Stop() + defer routeManager.Stop(nil) if testCase.removeSrvRouter { routeManager.serverRouter = nil diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 908279c88..503185f03 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -8,6 +8,7 @@ import ( 
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) @@ -17,10 +18,10 @@ type MockManager struct { UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func() + StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } @@ -65,8 +66,8 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop mock implementation of Stop from Manager interface -func (m *MockManager) Stop() { +func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { - m.StopFunc() + m.StopFunc(stateManager) } } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 65ea0f708..c121b7d77 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -1,6 +1,7 @@ package refcounter import ( + "encoding/json" "errors" "fmt" "runtime" @@ -70,6 +71,19 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key } } +// LoadData loads the data from the existing counter +func (rm *Counter[Key, I, O]) LoadData( + existingCounter *Counter[Key, I, O], +) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + rm.refCountMap = existingCounter.refCountMap + rm.idMap = existingCounter.idMap +} + // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { @@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() { clear(rm.idMap) } +// MarshalJSON implements the json.Marshaler interface for Counter. +func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + }{ + RefCountMap: rm.refCountMap, + IDMap: rm.idMap, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Counter. 
+func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + var temp struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + rm.refCountMap = temp.RefCountMap + rm.idMap = temp.IDMap + + return nil +} + func getCallerInfo(depth int, maxDepth int) (string, bool) { if depth >= maxDepth { return "", false diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go new file mode 100644 index 000000000..425908922 --- /dev/null +++ b/client/internal/routemanager/systemops/state.go @@ -0,0 +1,32 @@ +package systemops + +import ( + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" +) + +type ShutdownState struct { + Counter *ExclusionCounter `json:"counter,omitempty"` + mu sync.RWMutex +} + +func (s *ShutdownState) Name() string { + return "route_state" +} + +func (s *ShutdownState) Cleanup() error { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.Counter == nil { + return nil + } + + sysops := NewSysOps(nil, nil) + sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) + sysops.refCounter.LoadData(s.Counter) + + return sysops.refCounter.Flush() +} diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index 5e97a4a5f..ca8aea3fb 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -9,14 +9,15 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { return nil } @@ -28,6 +29,10 @@ func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error { return nil } +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 9258f4a4e..4ff34aa51 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -30,7 +31,9 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var ErrRoutingIsSeparate = errors.New("routing is separate") -func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + stateManager.RegisterState(&ShutdownState{}) + 
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) @@ -53,9 +56,18 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn // These errors are not critical, but also we should not track and try to remove the routes either. return nexthop, refcounter.ErrIgnore } + + r.updateState(stateManager) + return nexthop, err }, - r.removeFromRouteTable, + func(prefix netip.Prefix, nexthop Nexthop) error { + // remove from state even if we have trouble removing it from the route table + // it could be already gone + r.updateState(stateManager) + + return r.removeFromRouteTable(prefix, nexthop) + }, ) r.refCounter = refCounter @@ -63,7 +75,17 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn return r.setupHooks(initAddresses) } -func (r *SysOps) cleanupRefCounter() error { +func (r *SysOps) updateState(stateManager *statemanager.Manager) { + state := getState(stateManager) + + state.Counter = r.refCounter + + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + +func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { if r.refCounter == nil { return nil } @@ -76,6 +98,10 @@ func (r *SysOps) cleanupRefCounter() error { return fmt.Errorf("flush route manager: %w", err) } + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + return fmt.Errorf("delete state: %w", err) + } + return nil } @@ -506,3 +532,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } + +func getState(stateManager *statemanager.Manager) *ShutdownState { + var shutdownState *ShutdownState + if state := stateManager.GetState(shutdownState); state != nil { + shutdownState = state.(*ShutdownState) + } else { + shutdownState = &ShutdownState{} + } + + return shutdownState +} diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 238225807..5b7b13f97 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -61,7 +61,14 @@ func TestAddRemoveRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun53%d", n), + Address: "100.65.75.2/24", + WGPrivKey: peerPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -70,10 +77,10 @@ func TestAddRemoveRoutes(t *testing.T) { r := NewSysOps(wgInterface, nil) - _, _, err = r.SetupRouting(nil) + _, _, err = r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) @@ -213,7 +220,15 @@ func TestAddExistAndRemoveRoute(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, 
newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun53%d", n), + Address: "100.65.75.2/24", + WGPort: 33100, + WGPrivKey: peerPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -345,7 +360,15 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen newNet, err := stdnet.NewNet() require.NoError(t, err) - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: interfaceName, + Address: ipAddressCIDR, + WGPrivKey: peerPrivateKey.String(), + WGPort: listenPort, + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WireGuard interface") err = wgInterface.Create() @@ -380,10 +403,10 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil) + _, _, err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 7cfb2b298..bf06f3739 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -9,17 +9,18 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() @@ -46,6 +47,18 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error { return nil } +func (r *SysOps) notify() { + prefixes := make([]netip.Prefix, 0, len(r.prefixes)) + for prefix := range r.prefixes { + prefixes = append(prefixes, prefix) + } + r.notifier.OnNewPrefixes(prefixes) +} + +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil @@ -54,11 +67,3 @@ func EnableIPForwarding() error { func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { return false, netip.Prefix{} } - -func (r *SysOps) notify() { - prefixes := make([]netip.Prefix, 0, len(r.prefixes)) - for prefix := range r.prefixes { - prefixes = append(prefixes, prefix) - } - r.notifier.OnNewPrefixes(prefixes) -} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 2d0c57826..0124fd95e 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -18,6 +18,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" 
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -85,10 +86,10 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { if isLegacy() { log.Infof("Using legacy routing setup") - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } if err = addRoutingTableName(); err != nil { @@ -104,7 +105,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb defer func() { if err != nil { - if cleanErr := r.CleanupRouting(); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -116,7 +117,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb if errors.Is(err, syscall.EOPNOTSUPP) { log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") setIsLegacy(true) - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } @@ -128,9 +129,9 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. 
-func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { if isLegacy() { - return r.cleanupRefCounter() + return r.cleanupRefCounter(stateManager) } var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index a2bbf35cf..0f8f2a341 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -13,15 +13,16 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 3f756788e..b1732a080 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -130,12 +131,12 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go new file mode 100644 index 000000000..a5a14f807 --- /dev/null +++ b/client/internal/statemanager/manager.go @@ -0,0 +1,298 @@ +package statemanager + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "reflect" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +// State interface defines the methods that all state types must implement +type State interface { + Name() string + Cleanup() error +} + +// Manager handles the persistence and management of various states +type Manager struct { + mu sync.Mutex + cancel context.CancelFunc + done chan struct{} + + filePath string + // holds the states that are registered with the manager and that are to be persisted + states map[string]State + // holds the state names that have been updated and need 
to be persisted with the next save + dirty map[string]struct{} + // holds the type information for each registered state + stateTypes map[string]reflect.Type +} + +// New creates a new Manager instance +func New(filePath string) *Manager { + return &Manager{ + filePath: filePath, + states: make(map[string]State), + dirty: make(map[string]struct{}), + stateTypes: make(map[string]reflect.Type), + } +} + +// Start starts the state manager periodic save routine +func (m *Manager) Start() { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + var ctx context.Context + ctx, m.cancel = context.WithCancel(context.Background()) + m.done = make(chan struct{}) + + go m.periodicStateSave(ctx) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + m.cancel() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } + } + + return nil +} + +// RegisterState registers a state with the manager but doesn't attempt to persist it. +// Pass an uninitialized state to register it. +func (m *Manager) RegisterState(state State) { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + name := state.Name() + m.states[name] = nil + m.stateTypes[name] = reflect.TypeOf(state).Elem() +} + +// GetState returns the state for the given type +func (m *Manager) GetState(state State) State { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.states[state.Name()] +} + +// UpdateState updates the state in the manager and marks it as dirty for the next save. +// The state will be replaced with the new one. +func (m *Manager) UpdateState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), state) +} + +// DeleteState removes the state from the manager and marks it as dirty for the next save. +// Pass an uninitialized state to delete it. +func (m *Manager) DeleteState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), nil) +} + +func (m *Manager) setState(name string, state State) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.states[name]; !exists { + return fmt.Errorf("state %s not registered", name) + } + + m.states[name] = state + m.dirty[name] = struct{}{} + + return nil +} + +func (m *Manager) periodicStateSave(ctx context.Context) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + defer close(m.done) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := m.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + } + } +} + +// PersistState persists the states that have been updated since the last save. 
+func (m *Manager) PersistState(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.dirty) == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + done := make(chan error, 1) + + go func() { + data, err := json.MarshalIndent(m.states, "", " ") + if err != nil { + done <- fmt.Errorf("marshal states: %w", err) + return + } + + // nolint:gosec + if err := os.WriteFile(m.filePath, data, 0640); err != nil { + done <- fmt.Errorf("write state file: %w", err) + return + } + + done <- nil + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + if err != nil { + return err + } + } + + log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + + clear(m.dirty) + + return nil +} + +// loadState loads the existing state from the state file +func (m *Manager) loadState() error { + data, err := os.ReadFile(m.filePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + log.Debug("state file does not exist") + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + var rawStates map[string]json.RawMessage + if err := json.Unmarshal(data, &rawStates); err != nil { + log.Warn("State file appears to be corrupted, attempting to delete it") + if err := os.Remove(m.filePath); err != nil { + log.Errorf("Failed to delete corrupted state file: %v", err) + } else { + log.Info("State file deleted") + } + return fmt.Errorf("unmarshal states: %w", err) + } + + var merr *multierror.Error + + for name, rawState := range rawStates { + stateType, ok := m.stateTypes[name] + if !ok { + merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) + continue + } + + if string(rawState) == "null" { + continue + } + + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) + continue + } + + m.states[name] = statePtr + log.Debugf("loaded state: %s", name) + } + + return nberrors.FormatErrorOrNil(merr) +} + +// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. +// If the cleanup is successful, the state is marked for deletion. 
+func (m *Manager) PerformCleanup() error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.loadState(); err != nil { + log.Warnf("Failed to load state during cleanup: %v", err) + } + + var merr *multierror.Error + for name, state := range m.states { + if state == nil { + // If no state was found in the state file, we don't mark the state dirty nor return an error + continue + } + + log.Infof("client was not shut down properly, cleaning up %s", name) + if err := state.Cleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) + } else { + // mark for deletion on cleanup success + m.states[name] = nil + m.dirty[name] = struct{}{} + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go new file mode 100644 index 000000000..96d6a9f12 --- /dev/null +++ b/client/internal/statemanager/path.go @@ -0,0 +1,35 @@ +package statemanager + +import ( + "os" + "path/filepath" + "runtime" + + log "github.com/sirupsen/logrus" +) + +// GetDefaultStatePath returns the path to the state file based on the operating system +// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +func GetDefaultStatePath() string { + var path string + + switch runtime.GOOS { + case "windows": + path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + case "darwin", "linux": + path = "/var/lib/netbird/state.json" + case "freebsd", "openbsd", "netbsd", "dragonfly": + path = "/var/db/netbird/state.json" + // ios/android don't need state + default: + return "" + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + log.Errorf("Error creating directory %s: %v. 
Continuing without state support.", dir, err) + return "" + } + + return path +} diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go deleted file mode 100644 index 369ba99db..000000000 --- a/client/internal/wgproxy/factory_linux.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build !android - -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" - "github.com/netbirdio/netbird/client/internal/wgproxy/usp" -) - -type Factory struct { - wgPort int - ebpfProxy *ebpf.WGEBPFProxy -} - -func NewFactory(userspace bool, wgPort int) *Factory { - f := &Factory{wgPort: wgPort} - - if userspace { - return f - } - - ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) - err := ebpfProxy.Listen() - if err != nil { - log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) - return f - } - - f.ebpfProxy = ebpfProxy - return f -} - -func (w *Factory) GetProxy() Proxy { - if w.ebpfProxy != nil { - p := &ebpf.ProxyWrapper{ - WgeBPFProxy: w.ebpfProxy, - } - return p - } - return usp.NewWGUserSpaceProxy(w.wgPort) -} - -func (w *Factory) Free() error { - if w.ebpfProxy == nil { - return nil - } - return w.ebpfProxy.Free() -} diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go deleted file mode 100644 index f930b09b3..000000000 --- a/client/internal/wgproxy/factory_nonlinux.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !linux || android - -package wgproxy - -import "github.com/netbirdio/netbird/client/internal/wgproxy/usp" - -type Factory struct { - wgPort int -} - -func NewFactory(_ bool, wgPort int) *Factory { - return &Factory{wgPort: wgPort} -} - -func (w *Factory) GetProxy() Proxy { - return usp.NewWGUserSpaceProxy(w.wgPort) -} - -func (w *Factory) Free() error { - return nil -} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go deleted file mode 100644 index 558121cdd..000000000 --- a/client/internal/wgproxy/proxy.go +++ /dev/null @@ -1,15 +0,0 @@ -package wgproxy - -import ( - "context" - "net" -) - -// Proxy is a transfer layer between the relayed connection and the WireGuard -type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) error - EndpointAddr() *net.UDPAddr - Work() - Pause() - CloseConn() error -} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index dc13706bf..9d65bdbe0 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -138,12 +138,12 @@ func (c *Client) Stop() { c.ctxCancel() } -// ÏSetTraceLogLevel configure the logger to trace level +// SetTraceLogLevel configures the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) } -// getStatusDetails return with the list of the PeerInfos +// GetStatusDetails returns the list of PeerInfos func (c *Client) GetStatusDetails() *StatusDetails { fullStatus := c.recorder.GetFullStatus() diff --git a/client/server/server.go b/client/server/server.go index 0a4c18131..a03322081 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -39,6 +39,8 @@ const ( defaultMaxRetryInterval = 60 * time.Minute defaultMaxRetryTime = 14 * 24 * time.Hour defaultRetryMultiplier = 1.7 + + errRestoreResidualState = "failed to restore residual state: %v" ) // Server for service control.
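Reviewer note: to make the server.go hooks below easier to follow, here is a rough lifecycle sketch for the statemanager introduced in this patch (illustrative only; it assumes nothing beyond the manager.go and systemops APIs shown above):

	mgr := statemanager.New(statemanager.GetDefaultStatePath())
	mgr.RegisterState(&systemops.ShutdownState{}) // register before use so loadState can decode the type
	mgr.Start()                                   // starts the periodic save loop (10s ticker)
	// components mark their state dirty as it changes, e.g.:
	// _ = mgr.UpdateState(&systemops.ShutdownState{Counter: counter})
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	_ = mgr.PersistState(ctx) // flush dirty states to the state file
	_ = mgr.Stop(ctx)         // stop the save loop, honoring the context deadline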
@@ -95,6 +97,10 @@ func (s *Server) Start() error { defer s.mutex.Unlock() state := internal.CtxGetState(s.rootCtx) + if err := restoreResidualState(s.rootCtx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + // if current state contains any error, return it // in all other cases we can continue execution only if status is idle and up command was // not in the progress or already successfully established connection. @@ -292,6 +298,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() + if err := restoreResidualState(ctx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + state := internal.CtxGetState(ctx) defer func() { status, err := state.Status() @@ -549,6 +559,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.mutex.Lock() defer s.mutex.Unlock() + if err := restoreResidualState(callerCtx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + state := internal.CtxGetState(s.rootCtx) // if current state contains any error, return it diff --git a/client/server/state.go b/client/server/state.go new file mode 100644 index 000000000..509782e86 --- /dev/null +++ b/client/server/state.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// restoreResidualState checks if the client was not shut down in a clean way and restores residual state if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. +func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + // register the states we are interested in restoring + registerStates(mgr) + + var merr *multierror.Error + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + // persist state regardless of cleanup outcome.
It could've succeeded partially + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/server/state_generic.go b/client/server/state_generic.go new file mode 100644 index 000000000..e6c7bdd44 --- /dev/null +++ b/client/server/state_generic.go @@ -0,0 +1,14 @@ +//go:build !linux || android + +package server + +import ( + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) +} diff --git a/client/server/state_linux.go b/client/server/state_linux.go new file mode 100644 index 000000000..087628907 --- /dev/null +++ b/client/server/state_linux.go @@ -0,0 +1,18 @@ +//go:build !android + +package server + +import ( + "github.com/netbirdio/netbird/client/firewall/iptables" + "github.com/netbirdio/netbird/client/firewall/nftables" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) + mgr.RegisterState(&nftables.ShutdownState{}) + mgr.RegisterState(&iptables.ShutdownState{}) +} diff --git a/go.mod b/go.mod index e7e3c17a6..a6b83794d 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/r3labs/diff/v3 v3.0.1 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -210,6 +211,8 @@ require ( github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect + github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go.sum b/go.sum index e9bc318d6..412542d5e 100644 --- a/go.sum +++ b/go.sum @@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= +github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 
h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 16b2364fb..0b2b65142 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -873,7 +873,7 @@ services: zitadel: restart: 'always' networks: [netbird] - image: 'ghcr.io/zitadel/zitadel:v2.54.10' + image: 'ghcr.io/zitadel/zitadel:v2.64.1' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' env_file: - ./zitadel.env diff --git a/management/server/account.go b/management/server/account.go index 426d94bf4..e30e30759 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -157,6 +157,7 @@ type AccountManager interface { FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) + DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error } type DefaultAccountManager struct { @@ -1131,8 +1132,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewUnauthorizedToViewAccountSettingError() } halfYearLimit := 180 * 24 * time.Hour @@ -1144,28 +1149,19 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - var oldSettings *Settings - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthUpdate, accountID) - if err != nil { - return fmt.Errorf("failed to get account settings: %w", err) - } - - if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil { - return fmt.Errorf("failed to validate extra settings: %w", err) - } - - if err = transaction.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { - return fmt.Errorf("failed to update account settings: %w", err) - } - - return nil - }) + oldSettings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } + if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil { + return nil, err + } + + if 
err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { + return nil, err + } + am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) @@ -2026,10 +2022,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer func() { - if unlockPeer != nil { - unlockPeer() + if unlockAccount != nil { + unlockAccount() } }() @@ -2038,12 +2034,12 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var hasChanges bool var user *User err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2101,8 +2097,8 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error incrementing network serial: %w", err) } } - unlockPeer() - unlockPeer = nil + unlockAccount() + unlockAccount = nil return nil }) @@ -2115,7 +2111,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2128,7 +2124,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2141,13 +2137,24 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + removedGroupAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, removeOldGroups) if err != nil { - return fmt.Errorf("error getting account: %w", err) + return err } - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + newGroupsAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, addNewGroups) + if err != nil { + return err + } + + if removedGroupAffectsPeers || newGroupsAffectsPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + 
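// Reviewer note (descriptive, not a patch line): updateAccountPeers recalculates the
// network map and pushes it to every connected peer in the account; the
// areGroupChangesAffectPeers checks above exist so this fan-out is skipped when
// neither the added nor the removed JWT groups contain any peers.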
am.updateAccountPeers(ctx, account) + } } return nil @@ -2249,7 +2256,8 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context } if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + log.WithContext(ctx).Debugf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + return "", status.NewUserNotPartOfAccountError() } accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) @@ -2443,8 +2451,12 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return nil, err } - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewAccountSettingError() } return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) @@ -2467,8 +2479,8 @@ func addAllGroup(account *Account) error { defaultPolicy := &Policy{ ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: DefaultPolicyName, + Description: DefaultPolicyDescription, Enabled: true, Rules: []*PolicyRule{ { diff --git a/management/server/account_test.go b/management/server/account_test.go index 565c3129b..d0238fe0a 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1011,7 +1011,6 @@ func TestAccountManager_AddPeer(t *testing.T) { return } expectedPeerKey := key.PublicKey().String() - expectedSetupKey := setupKey.Key peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, @@ -1036,10 +1035,6 @@ func TestAccountManager_AddPeer(t *testing.T) { t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) } - if peer.SetupKey != expectedSetupKey { - t.Errorf("expecting just added peer to have SetupKey = %s, got %s", expectedSetupKey, peer.SetupKey) - } - if account.Network.CurrentSerial() != 1 { t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial()) } @@ -1123,68 +1118,17 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - return - } - - userID := "account_creator" - - account, err := createAccount(manager, "test_account", userID, "") - if err != nil { - t.Fatal(err) - } - - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - t.Fatal("error creating setup key") - return - } - - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return - } - - getPeer := func() *nbpeer.Peer { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return nil - } - expectedPeerKey := key.PublicKey().String() - - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ - Key: expectedPeerKey, - Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) - if err != 
nil { - t.Fatalf("expecting peer1 to be added, got failure %v", err) - return nil - } - - return peer - } - - peer1 := getPeer() - peer2 := getPeer() - peer3 := getPeer() - - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - return - } - - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) +func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := group.Group{ - ID: "group-id", + ID: "groupA", Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + Peers: []string{}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return } policy := Policy{ @@ -1195,8 +1139,89 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { { ID: xid.New().String(), Enabled: true, - Sources: []string{"group-id"}, - Destinations: []string{"group-id"}, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + group.Peers = []string{peer1.ID, peer2.ID, peer3.ID} + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + manager, account, peer1, _, _ := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + manager, account, peer1, peer2, _ := setupNetworkMapTest(t) + + group := group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return + } + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, Bidirectional: true, Action: PolicyTrafficActionAccept, }, @@ -1204,107 +1229,138 @@ func 
TestAccountManager_NetworkUpdates(t *testing.T) { } wg := sync.WaitGroup{} - t.Run("save group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + wg.Add(1) + go func() { + defer wg.Done() - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { - t.Errorf("save group: %v", err) - return + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("save policy: %v", err) + return + } - t.Run("delete policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + wg.Wait() +} - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() +func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + manager, account, peer1, _, peer3 := setupNetworkMapTest(t) - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return + group := group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer3.ID}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return + } + + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("save policy: %v", err) + return + } + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 1 { + t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) + if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { + t.Errorf("delete peer: %v", err) + return + } - t.Run("save policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + wg.Wait() +} - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() +func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { - t.Errorf("delete default rule: %v", err) - return + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + group :=
group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + } + + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("save policy: %v", err) + return + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) - t.Run("delete peer update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + // a clean policy is a prerequisite for deleting the group + if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil { + t.Errorf("delete policy: %v", err) + return + } - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 1 { - t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) - } - }() + if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + t.Errorf("delete group: %v", err) + return + } - if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { - t.Errorf("delete peer: %v", err) - return - } - - wg.Wait() - }) - - t.Run("delete group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - // clean policy is pre requirement for delete group - _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) - - if err := manager.DeleteGroup(context.Background(), account.Id, userID, group.ID); err != nil { - t.Errorf("delete group: %v", err) - return - } - - wg.Wait() - }) + wg.Wait() } func TestAccountManager_DeletePeer(t *testing.T) { @@ -2300,7 +2356,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: false, }, LoginExpirationEnabled: true, - SetupKey: "key", }, "peer-2": { Status: &nbpeer.PeerStatus{ Connected: false, LoginExpired: false, }, LoginExpirationEnabled: true, - SetupKey: "key", }, }, expiration: time.Second, @@ -2462,7 +2516,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { LoginExpired: false, }, InactivityExpirationEnabled: true, - SetupKey: "key", }, "peer-2": { Status: &nbpeer.PeerStatus{ Connected: false, LoginExpired: false, }, InactivityExpirationEnabled: true, - SetupKey: "key", }, }, expiration: time.Second, @@ -2749,3 +2801,73 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { return true } } + +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { + t.Helper() + + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + account, err :=
createAccount(manager, "test_account", userID, "") + if err != nil { + t.Fatal(err) + } + + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + if err != nil { + t.Fatal("error creating setup key") + } + + getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + expectedPeerKey := key.PublicKey().String() + + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + Status: &nbpeer.PeerStatus{ + Connected: true, + LastSeen: time.Now().UTC(), + }, + }) + if err != nil { + t.Fatalf("expecting peer to be added, got failure %v", err) + } + + return peer + } + + peer1 := getPeer(manager, setupKey) + peer2 := getPeer(manager, setupKey) + peer3 := getPeer(manager, setupKey) + + return manager, account, peer1, peer2, peer3 +} + +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + case <-time.After(500 * time.Millisecond): + t.Error("Timed out waiting for update message") + } +} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 188494241..603260dbc 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -146,6 +146,8 @@ const ( AccountPeerInactivityExpirationEnabled Activity = 65 AccountPeerInactivityExpirationDisabled Activity = 66 AccountPeerInactivityExpirationDurationUpdated Activity = 67 + + SetupKeyDeleted Activity = 68 ) var activityMap = map[Activity]Code{ @@ -219,6 +221,7 @@ var activityMap = map[Activity]Code{ AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, + SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/differs/netip.go b/management/server/differs/netip.go new file mode 100644 index 000000000..de4aa334c --- /dev/null +++ b/management/server/differs/netip.go @@ -0,0 +1,82 @@ +package differs + +import ( + "fmt" + "net/netip" + "reflect" + + "github.com/r3labs/diff/v3" +) + +// NetIPAddr is a custom differ for netip.Addr +type NetIPAddr struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPAddr) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Addr{})) +} + +func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + + if b.Kind() == reflect.Invalid { + 
cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromAddr, ok1 := a.Interface().(netip.Addr) + toAddr, ok2 := b.Interface().(netip.Addr) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Addr") + } + + if fromAddr.String() != toAddr.String() { + cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String()) + } + + return nil +} + +func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} + +// NetIPPrefix is a custom differ for netip.Prefix +type NetIPPrefix struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPPrefix) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{})) +} + +func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + if b.Kind() == reflect.Invalid { + cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromPrefix, ok1 := a.Interface().(netip.Prefix) + toPrefix, ok2 := b.Interface().(netip.Prefix) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Prefix") + } + + if fromPrefix.String() != toPrefix.String() { + cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String()) + } + + return nil +} + +func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} diff --git a/management/server/dns.go b/management/server/dns.go index 85cca221a..ace6c680d 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -86,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewDNSSettingsError() } return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) @@ -104,8 +108,12 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return err } - if !user.HasAdminPower() || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return status.NewUnauthorizedToViewDNSSettingsError() } oldSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) @@ -125,6 +133,14 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } } + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + + updateAccountPeers, err := am.areDNSSettingChangesAffectPeers(ctx, accountID, addedGroups, removedGroups) + if err != nil { + return fmt.Errorf("failed to check if dns settings changes affect peers: %w", err) + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err =
transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("failed to increment network serial: %w", err) @@ -145,7 +161,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID groupMap[g.ID] = g } - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) for _, id := range addedGroups { group, ok := groupMap[id] if ok { @@ -154,7 +169,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } } - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) for _, id := range removedGroups { group, ok := groupMap[id] if ok { @@ -163,15 +177,31 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } +// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. +func (am *DefaultAccountManager) areDNSSettingChangesAffectPeers(ctx context.Context, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := am.anyGroupHasPeers(ctx, accountID, addedGroups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return am.anyGroupHasPeers(ctx, accountID, removedGroups) +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ diff --git a/management/server/dns_test.go b/management/server/dns_test.go index c7f435b68..c675fc12c 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -6,9 +6,11 @@ import ( "net/netip" "reflect" "testing" + "time" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -476,3 +478,145 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Cache should contain name server group 'group2'") } } + +func TestDNSAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates + t.Run("saving dns setting with unused groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case 
<-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group-1", "ns-group-1", []nbdns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + // Saving DNS settings with groups that have peers should update account peers and send peer update + t.Run("saving dns setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA", "groupB"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged DNS settings with used groups should update account peers and not send peer update + // since there is no change in the network map + t.Run("saving unchanged dns setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA", "groupB"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates + t.Run("removing group with no peers from dns settings", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing group with peers from DNS settings should trigger updates to account peers and send peer updates + t.Run("removing group with peers from dns settings", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/group.go b/management/server/group.go index c902830c0..2584be24a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -37,12 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked { - return status.Errorf(status.PermissionDenied, "access to groups is blocked for users") + if user.AccountID != accountID { + return
status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if user.IsRegularUser() && settings.RegularUsersViewBlocked { + return status.NewUnauthorizedToViewGroupsError() } return nil @@ -84,7 +84,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } var ( @@ -128,6 +128,16 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user eventsToStore = append(eventsToStore, events...) } + newGroupIDs := make([]string, 0, len(newGroups)) + for _, newGroup := range newGroups { + newGroupIDs = append(newGroupIDs, newGroup.ID) + } + + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs) + if err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf(errNetworkSerialIncrementFmt, err) @@ -146,11 +156,13 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user storeEvent() } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -233,7 +245,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) @@ -265,12 +277,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) - return nil } @@ -282,7 +288,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } var ( @@ -324,12 +330,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta()) } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) - return allErrors } @@ -351,6 +351,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr group.Peers = append(group.Peers, peerID) } + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return 
fmt.Errorf(errNetworkSerialIncrementFmt, err) @@ -365,11 +370,13 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -394,6 +401,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return nil } + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf(errNetworkSerialIncrementFmt, err) @@ -408,11 +420,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -429,23 +443,23 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } } - if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.ID, group.AccountID); isLinked { + if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.ID, group.AccountID); isLinked { + if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.ID, group.AccountID); isLinked { + if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.ID, group.AccountID); isLinked { + if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.ID, group.AccountID); isLinked { + if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } @@ -473,7 +487,7 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } // isGroupLinkedToRoute checks if a group is linked to any route in the account. 
-func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, groupID string, accountID string) (bool, *route.Route) { +func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) { routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) @@ -490,7 +504,7 @@ func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, group } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, groupID string, accountID string) (bool, *Policy) { +func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) { policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) @@ -508,7 +522,7 @@ func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, grou } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, groupID string, accountID string) (bool, *nbdns.NameServerGroup) { +func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) @@ -527,7 +541,7 @@ func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, groupID } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, groupID string, accountID string) (bool, *SetupKey) { +func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) { setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) @@ -543,7 +557,7 @@ func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, gr } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, groupID string, accountID string) (bool, *User) { +func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) { users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) @@ -557,3 +571,48 @@ func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, groupI } return false, nil } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. 
+func (am *DefaultAccountManager) anyGroupHasPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + for _, groupID := range groupIDs { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + return false, err + } + + if group.HasPeers() { + return true, nil + } + } + + return false, nil +} + +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, groupID := range groupIDs { + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil + } + if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked { + return true, nil + } + } + + return false, nil +} diff --git a/management/server/group/group.go b/management/server/group/group.go index 79dfd995c..d293e1afc 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -44,3 +44,8 @@ func (g *Group) Copy() *Group { copy(group.Peers, g.Peers) return group } + +// HasPeers checks if the group has any peers. +func (g *Group) HasPeers() bool { + return len(g.Peers) > 0 +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 89b68ad6c..1e59b74ef 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -4,13 +4,16 @@ import ( "context" "errors" "fmt" + "net/netip" "testing" + "time" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -384,3 +387,312 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A } return am, acc, nil } + +func TestGroupAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1.ID, peer3.ID}, + }, + { + ID: "groupD", + Name: "GroupD", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Saving a group that is not linked to any resource should not update account peers + t.Run("saving unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + 
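/* Note — what counts as a group change that affects peers.
areGroupChangesAffectPeers above is the single gate every group mutation in
this diff passes through before fanning out updateAccountPeers. Spelled out as
a sketch, where linkedToDNS, linkedToPolicy, and linkedToRoute are shorthand
for the isGroupLinkedTo* helpers and not real functions in this package:

	affectsPeers := slices.Contains(dnsSettings.DisabledManagementGroups, groupID) ||
		linkedToDNS(groupID) ||
		linkedToPolicy(groupID) ||
		linkedToRoute(groupID)

Linkage to a setup key or a user is only checked in validateDeleteGroup,
presumably because neither kind of link changes any peer's network map.
*/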
t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Adding a peer to a group that is not linked to any resource should not update account peers + // and not send peer update + t.Run("adding peer to unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupAddPeer(context.Background(), account.Id, "groupB", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing a peer from a group that is not linked to any resource should not update account peers + // and not send peer update + t.Run("removing peer from unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupDeletePeer(context.Background(), account.Id, "groupB", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting group should not update account peers and not send peer update + t.Run("deleting group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupB") + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // adding a group to policy + err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }, false) + assert.NoError(t, err) + + // Saving a group linked to policy should update account peers and send peer update + t.Run("saving linked group to policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving an unchanged group should trigger account peers update and not send peer update + // since there is no change in the network map + t.Run("saving unchanged group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // adding peer to a used group should update account peers and send peer update + t.Run("adding peer to linked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupAddPeer(context.Background(), account.Id, "groupA", peer3.ID) + assert.NoError(t, err) + + select { + case
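/* Note — the assertion helpers used throughout these tests.
peerShouldReceiveUpdate and peerShouldNotReceiveUpdate are not part of this
diff; a plausible sketch of their shape, assuming they sit next to
setupNetworkMapTest in this package's test helpers:

	func peerShouldReceiveUpdate(t *testing.T, updMsg <-chan *UpdateMessage) {
		t.Helper()
		select {
		case update := <-updMsg:
			if update == nil {
				t.Error("received nil update")
			}
		case <-time.After(500 * time.Millisecond):
			t.Error("timed out waiting for update")
		}
	}

	func peerShouldNotReceiveUpdate(t *testing.T, updMsg <-chan *UpdateMessage) {
		t.Helper()
		select {
		case update := <-updMsg:
			t.Errorf("unexpected update message received: %v", update)
		case <-time.After(500 * time.Millisecond):
			// silence within the window is the expected outcome
		}
	}

Each subtest wraps the helper in a goroutine with a done channel so the main
test body can bound the wait with its own one-second timeout.
*/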
<-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // removing peer from a linked group should update account peers and send peer update + t.Run("removing peer from linked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupDeletePeer(context.Background(), account.Id, "groupA", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to name server group should update account peers and send peer update + t.Run("saving group linked to name server group", func(t *testing.T) { + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupC"}, + true, nil, true, userID, false, + ) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1.ID, peer3.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to route should update account peers and send peer update + t.Run("saving group linked to route", func(t *testing.T) { + newRoute := route.Route{ + ID: "route", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupA"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupC"}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to dns settings should update account peers and send peer update + t.Run("saving group linked to dns settings", func(t *testing.T) { + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupD"}, + }) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupD", + Name: "GroupD", + Peers: []string{peer1.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/http/accounts_handler.go 
b/management/server/http/accounts_handler.go index 5ff07c821..8959932d1 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -137,13 +137,15 @@ func toAccountResponse(accountID string, settings *server.Settings) *api.Account } apiSettings := api.AccountSettings{ - PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), - PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, - GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, - JwtGroupsEnabled: &settings.JWTGroupsEnabled, - JwtGroupsClaimName: &settings.JWTGroupsClaimName, - JwtAllowGroups: &jwtAllowGroups, - RegularUsersViewBlocked: settings.RegularUsersViewBlocked, + PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), + PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, + PeerInactivityExpiration: int(settings.PeerInactivityExpiration.Seconds()), + PeerInactivityExpirationEnabled: settings.PeerInactivityExpirationEnabled, + GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, + JwtGroupsEnabled: &settings.JWTGroupsEnabled, + JwtGroupsClaimName: &settings.JWTGroupsClaimName, + JwtAllowGroups: &jwtAllowGroups, + RegularUsersViewBlocked: settings.RegularUsersViewBlocked, } if settings.Extra != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 9d5148248..9b4592ccf 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -530,10 +530,9 @@ components: type: string example: reusable expires_in: - description: Expiration time in seconds + description: Expiration time in seconds, 0 will mean the key never expires type: integer - minimum: 86400 - maximum: 31536000 + minimum: 0 example: 86400 revoked: description: Setup key revocation status @@ -2018,6 +2017,32 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Setup Key + description: Delete a Setup Key + tags: [ Setup Keys ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: keyId + required: true + schema: + type: string + description: The unique identifier of a setup key + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/groups: get: summary: List all Groups diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index e2870d5d8..c1ef1ba21 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1101,7 +1101,7 @@ type SetupKeyRequest struct { // Ephemeral Indicate that the peer will be ephemeral or not Ephemeral *bool `json:"ephemeral,omitempty"` - // ExpiresIn Expiration time in seconds + // ExpiresIn Expiration time in seconds, 0 will mean the key never expires ExpiresIn int `json:"expires_in"` // Name Setup Key name diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3f8a8554d..c3928bff6 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -141,6 +141,7 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() { apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") 
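/* Note — exercising the DELETE route registered just below.
A minimal client sketch using net/http; baseURL, apiToken, and keyID are
placeholders, and the "Token" scheme mirrors the TokenAuth security option
declared in openapi.yml above:

	req, err := http.NewRequest(http.MethodDelete, baseURL+"/api/setup-keys/"+keyID, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Token "+apiToken)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// DeleteSetupKey in setupkeys_handler.go responds 200 with an empty JSON object
*/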
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") + apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") } func (apiHandler *apiHandler) addPoliciesEndpoint() { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index f933eee14..dd49c03b8 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -13,12 +13,13 @@ import ( "time" "github.com/gorilla/mux" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "golang.org/x/exp/maps" "github.com/stretchr/testify/assert" @@ -168,7 +169,6 @@ func TestGetPeers(t *testing.T) { peer := &nbpeer.Peer{ ID: testPeerID, Key: "key", - SetupKey: "setupkey", IP: net.ParseIP("100.64.0.1"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "PeerName", diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8514f0b55..31859f59b 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -61,10 +61,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request expiresIn := time.Duration(req.ExpiresIn) * time.Second - day := time.Hour * 24 - year := day * 365 - if expiresIn < day || expiresIn > year { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) + if expiresIn < 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn can not be in the past"), w) return } @@ -76,6 +74,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { @@ -83,7 +82,11 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request return } - writeSuccess(r.Context(), w, setupKey) + apiSetupKeys := toResponseBody(setupKey) + // for the creation we need to send the plain key + apiSetupKeys.Key = setupKey.Key + + util.WriteJSONObject(r.Context(), w, apiSetupKeys) } // GetSetupKey is a GET request to get a SetupKey by ID @@ -98,7 +101,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) return } @@ -123,7 +126,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) return } @@ -181,6 +184,30 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques 
util.WriteJSONObject(r.Context(), w, apiSetupKeys) } +func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + keyID := vars["keyId"] + if len(keyID) == 0 { + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) + return + } + + err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, emptyObject{}) +} + func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -206,7 +233,7 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey { return &api.SetupKey{ Id: key.Id, - Key: key.Key, + Key: key.KeySecret, Name: key.Name, Expires: key.ExpiresAt, Type: string(key.Type), diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 2d15287af..09256d0ea 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -67,6 +67,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { return []*server.SetupKey{defaultKey}, nil }, + + DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error { + if keyID == defaultKey.Id { + return nil + } + return status.Errorf(status.NotFound, "key %s not found", keyID) + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { @@ -81,18 +88,21 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } func TestSetupKeysHandlers(t *testing.T) { - defaultSetupKey := server.GenerateDefaultSetupKey() + defaultSetupKey, _ := server.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID adminUser := server.NewAdminUser("test_user") - newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, + newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, server.SetupKeyUnlimitedUsage, true) + newSetupKey.Key = plainKey updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Revoked = true + expectedNewKey := toResponseBody(newSetupKey) + expectedNewKey.Key = plainKey tt := []struct { name string requestType string @@ -134,7 +144,7 @@ func TestSetupKeysHandlers(t *testing.T) { []byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))), expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(newSetupKey), + expectedSetupKey: expectedNewKey, }, { name: "Update Setup Key", @@ -150,6 +160,14 @@ func TestSetupKeysHandlers(t *testing.T) { expectedBody: true, expectedSetupKey: toResponseBody(updatedDefaultSetupKey), }, + { + name: "Delete Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/" + defaultSetupKey.Id, + requestBody: 
bytes.NewBuffer([]byte("")), + expectedStatus: http.StatusOK, + expectedBody: false, + }, } handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser) @@ -164,6 +182,7 @@ func TestSetupKeysHandlers(t *testing.T) { router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index bdf744d21..843fa575e 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -267,7 +267,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { peersSSHEnabled++ } - if peer.SetupKey == "" { + if peer.UserID != "" { userPeers++ } diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 4c8baea5e..6f12d94b4 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -2,13 +2,16 @@ package migration import ( "context" + "crypto/sha256" "database/sql" + b64 "encoding/base64" "encoding/gob" "encoding/json" "errors" "fmt" "net" "strings" + "unicode/utf8" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -205,3 +208,90 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi return nil } + +func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error { + oldColumnName := "key" + newColumnName := "key_secret" + + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if err := db.Transaction(func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&model, newColumnName) { + log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", newColumnName, tableName) + if err := tx.Migrator().AddColumn(&model, newColumnName); err != nil { + return fmt.Errorf("add column %s: %w", newColumnName, err) + } + } + + var rows []map[string]any + if err := tx.Table(tableName). + Select("id", oldColumnName, newColumnName). + Where(newColumnName + " IS NULL OR " + newColumnName + " = ''"). + Where("SUBSTR(" + oldColumnName + ", 9, 1) = '-'"). 
+ Find(&rows).Error; err != nil { + return fmt.Errorf("find rows with empty secret key and matching pattern: %w", err) + } + + if len(rows) == 0 { + log.WithContext(ctx).Infof("No plain setup keys found in table %s, no migration needed", tableName) + return nil + } + + for _, row := range rows { + var plainKey string + if columnValue := row[oldColumnName]; columnValue != nil { + value, ok := columnValue.(string) + if !ok { + return fmt.Errorf("type assertion failed") + } + plainKey = value + } + + secretKey := hiddenKey(plainKey, 4) + + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, secretKey).Error; err != nil { + return fmt.Errorf("update row with secret key: %w", err) + } + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(oldColumnName, encodedHashedKey).Error; err != nil { + return fmt.Errorf("update row with hashed key: %w", err) + } + } + + if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { + log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) + } + + return nil + }); err != nil { + return err + } + + log.Printf("Migration of plain setup key to hashed setup key completed") + return nil +} + +// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func hiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) +} diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 5a1926641..51358c7ad 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -160,3 +160,72 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr) assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged") } + +func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} + +func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + KeySecret: "EEFDA****", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + 
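/* Note — the migration's key transformation, worked through.
For the plain key this test inserts, the two derived columns come out as:

	plain := "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382"
	secret := hiddenKey(plain, 4) // "EEFDA****": five-character prefix plus four masking stars
	sum := sha256.Sum256([]byte(plain))
	hashed := base64.StdEncoding.EncodeToString(sum[:]) // "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE="

The SUBSTR(key, 9, 1) = '-' filter in the migration matches the dash at
position nine of UUID-shaped plain keys and skips rows whose key column
already holds a base64 hash, which is why the two "already migrated" cases
below pass through unchanged. AddPeer in peer.go applies the same
SHA-256-plus-base64 transform to a presented key before looking it up, so the
plain text never has to be stored again.
*/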
require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} + +func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 08bd15e10..2f91a0478 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -110,6 +110,14 @@ type MockAccountManager struct { GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) + DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error +} + +func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { + if am.DeleteSetupKeyFunc != nil { + return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) + } + return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5d2f9d90f..48ff35987 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -26,8 +26,12 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewNSGroupsError() } return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) @@ -41,7 +45,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "no permission to create nameserver for this account") + return nil, status.NewUserNotPartOfAccountError() } newNSGroup := &nbdns.NameServerGroup{ @@ -62,6 +66,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, 
acco return nil, err } + updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, newNSGroup.Groups) + if err != nil { + return nil, err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("failed to increment network serial: %w", err) @@ -79,11 +88,13 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("error getting account: %w", err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return newNSGroup.Copy(), nil } @@ -100,15 +111,19 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") + return status.NewUserNotPartOfAccountError() } - _, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + oldNSGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) if err != nil { return err } - err = am.validateNameServerGroup(ctx, accountID, nsGroupToSave) + if err = am.validateNameServerGroup(ctx, accountID, nsGroupToSave); err != nil { + return err + } + + updateAccountPeers, err := am.areNameServerGroupChangesAffectPeers(ctx, nsGroupToSave, oldNSGroup) if err != nil { return err } @@ -130,11 +145,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -147,7 +164,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") + return status.NewUserNotPartOfAccountError() } nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) @@ -155,6 +172,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } + updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, nsGroup.Groups) + if err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("failed to increment network serial: %w", err) @@ -172,11 +194,13 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, 
nsGroup.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -188,8 +212,12 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewNSGroupsError() } return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) @@ -229,6 +257,24 @@ func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, ac return nil } +// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. +func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(ctx context.Context, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { + if !newNSGroup.Enabled && !oldNSGroup.Enabled { + return false, nil + } + + hasPeers, err := am.anyGroupHasPeers(ctx, newNSGroup.AccountID, newNSGroup.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return am.anyGroupHasPeers(ctx, oldNSGroup.AccountID, oldNSGroup.Groups) +} + func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { if !primary && len(domains) == 0 { return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 8a3fe6eb0..96637cd39 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -4,7 +4,9 @@ import ( "context" "net/netip" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" @@ -935,3 +937,179 @@ func TestValidateDomain(t *testing.T) { } } + +func TestNameServerAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + var newNameServerGroupA *nbdns.NameServerGroup + var newNameServerGroupB *nbdns.NameServerGroup + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Creating a nameserver group with a distribution group no peers should not update account peers + // and not send peer update + t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupA, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroupA", "nsGroupA", 
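/* Note — when a nameserver group change fans out.
areNameServerGroupChangesAffectPeers above short-circuits on the enabled flags
before looking at membership. A sketch of the decision, not the literal
implementation (anyHasPeers is shorthand for anyGroupHasPeers):

	if !newNSGroup.Enabled && !oldNSGroup.Enabled {
		return false, nil // inactive before and after: nobody's DNS config changes
	}
	return anyHasPeers(newNSGroup.Groups) || anyHasPeers(oldNSGroup.Groups), nil

The subtests here exercise both sides using groupA (no peers) and groupB
(all three peers).
*/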
[]nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // saving a nameserver group with a distribution group with no peers should not update account peers + // and not send peer update + t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupA) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Creating a nameserver group with a distribution group that has peers should update account peers and send peer update + t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroupB", "nsGroupB", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupB"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // saving a nameserver group with a distribution group with peers should update account peers and send peer update + t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB.NameServers = []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + } + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // saving unchanged nameserver group should update account peers and not send peer update + t.Run("saving unchanged nameserver group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB.NameServers = []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + } + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting a nameserver group should update account peers and send peer update + t.Run("deleting nameserver group", func(t *testing.T) { + done := make(chan struct{}) + go
func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeleteNameServerGroup(context.Background(), account.Id, newNameServerGroupB.ID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/network.go b/management/server/network.go index a5b188b46..8fb6a8b3c 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -41,9 +41,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. - Serial uint64 + Serial uint64 `diff:"-"` - mu sync.Mutex `json:"-" gorm:"-"` + mu sync.Mutex `json:"-" gorm:"-" diff:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer.go b/management/server/peer.go index f49c9609f..c58e7b225 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "net" "slices" @@ -57,7 +59,7 @@ func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userI } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + return nil, status.NewUserNotPartOfAccountError() } return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) @@ -72,7 +74,7 @@ func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, us } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return nil, status.NewUserNotPartOfAccountError() } settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) @@ -87,9 +89,7 @@ func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, us peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) - regularUser := !user.HasAdminPower() && !user.IsServiceUser - - if regularUser && settings.RegularUsersViewBlocked { + if user.IsRegularUser() && settings.RegularUsersViewBlocked { return peers, nil } @@ -98,7 +98,7 @@ func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, us return nil, err } for _, peer := range accountPeers { - if regularUser && user.Id != peer.UserID { + if user.IsRegularUser() && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin continue } @@ -107,7 +107,7 @@ func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, us peersMap[peer.ID] = p } - if !regularUser { + if user.IsAdminOrServiceUser() { return peers, nil } @@ -215,7 +215,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return nil, status.NewUserNotPartOfAccountError() } peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID) @@ -247,7 +247,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - if peer.Name != update.Name { + peerLabelUpdated := peer.Name != update.Name + + if peerLabelUpdated { peer.Name = update.Name existingLabels, err := 
am.getPeerDNSLabels(ctx, accountID) @@ -306,11 +308,13 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, err } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf(errGetAccountFmt, err) + if peerLabelUpdated { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return peer, nil } @@ -347,6 +351,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, accountID stri FirewallRulesIsEmpty: true, }, }, + NetworkMap: &NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -364,7 +369,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) @@ -372,15 +377,22 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID) + if err != nil { + return err + } + if err = am.deletePeers(ctx, accountID, userID, []*nbpeer.Peer{peer}); err != nil { return err } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -434,6 +446,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } upperKey := strings.ToUpper(setupKey) + hashedKey := sha256.Sum256([]byte(upperKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) var accountID string var err error addedByUser := false @@ -441,7 +455,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser = true accountID, err = am.Store.GetAccountIDByUserID(userID) } else { - accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey) + accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } if err != nil { return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") @@ -470,9 +484,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer + var groupsToAdd []string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - var groupsToAdd []string var setupKeyID string var setupKeyName string var ephemeral bool @@ -486,7 +500,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s opEvent.Activity = activity.PeerAddedByUser } else { // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey) + sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey) if err != nil { return fmt.Errorf("failed to get setup key: %w", err) } @@ -527,7 +541,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, 
userID s ID: xid.New().String(), AccountID: accountID, Key: peer.Key, - SetupKey: upperKey, IP: freeIP, Meta: peer.Meta, Name: peer.Meta.Hostname, @@ -619,12 +632,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s unlock() unlock = nil + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, groupsToAdd) + if err != nil { + return nil, nil, nil, err + } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error getting account: %w", err) } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) if err != nil { @@ -979,7 +999,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return nil, status.NewUserNotPartOfAccountError() } settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) @@ -987,7 +1007,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } - if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked { + if user.IsRegularUser() && settings.RegularUsersViewBlocked { return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) } @@ -1074,7 +1094,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1266,3 +1286,13 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { } return labelMap } + +// isPeerInActiveGroup checks if the given peer is part of a group that is used +// in an active DNS, route, or ACL configuration. +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) { + peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID) + if err != nil { + return false, err + } + return am.areGroupChangesAffectPeers(ctx, accountID, peerGroupIDs) +} diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 9a53459a8..82e0acf3a 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -4,6 +4,7 @@ import ( "net" "net/netip" "slices" + "sort" "time" ) @@ -16,38 +17,36 @@ type Peer struct { AccountID string `json:"-" gorm:"index"` // WireGuard public key Key string `gorm:"index"` - // A setup key this peer was registered with - SetupKey string // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g.
peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` // The user ID that registered the peer - UserID string + UserID string `diff:"-"` // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin - LoginExpirationEnabled bool + LoginExpirationEnabled bool `diff:"-"` - InactivityExpirationEnabled bool + InactivityExpirationEnabled bool `diff:"-"` // LastLogin the time when peer performed last login operation - LastLogin time.Time + LastLogin time.Time `diff:"-"` // CreatedAt records the time the peer was created - CreatedAt time.Time + CreatedAt time.Time `diff:"-"` // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `diff:"-"` // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_"` + Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` } type PeerStatus struct { //nolint:revive @@ -109,6 +108,12 @@ type PeerSystemMeta struct { //nolint:revive } func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { + sort.Slice(p.NetworkAddresses, func(i, j int) bool { + return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac + }) + sort.Slice(other.NetworkAddresses, func(i, j int) bool { + return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac + }) equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool { return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP }) @@ -116,6 +121,12 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { return false } + sort.Slice(p.Files, func(i, j int) bool { + return p.Files[i].Path < p.Files[j].Path + }) + sort.Slice(other.Files, func(i, j int) bool { + return other.Files[i].Path < other.Files[j].Path + }) equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool { return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning }) @@ -172,24 +183,22 @@ func (p *Peer) Copy() *Peer { peerStatus = p.Status.Copy() } return &Peer{ - ID: p.ID, - AccountID: p.AccountID, - Key: p.Key, - SetupKey: p.SetupKey, - IP: p.IP, - Meta: p.Meta, - Name: p.Name, - DNSLabel: p.DNSLabel, - Status: peerStatus, - UserID: p.UserID, - SSHKey: p.SSHKey, - SSHEnabled: p.SSHEnabled, - LoginExpirationEnabled: p.LoginExpirationEnabled, - LastLogin: p.LastLogin, - CreatedAt: p.CreatedAt, - Ephemeral: p.Ephemeral, - Location: p.Location, - + ID: p.ID, + AccountID: p.AccountID, + Key: p.Key, + IP: p.IP, + Meta: p.Meta, + Name: p.Name, + DNSLabel: p.DNSLabel, + Status: peerStatus, + UserID: p.UserID, + SSHKey: p.SSHKey, + SSHEnabled: p.SSHEnabled, + LoginExpirationEnabled: p.LoginExpirationEnabled, + LastLogin: p.LastLogin, + CreatedAt: p.CreatedAt, + Ephemeral: p.Ephemeral, + Location: p.Location, InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } diff --git a/management/server/peer/peer_test.go b/management/server/peer/peer_test.go index 7b94f68c6..3d3a2e311 100644 --- a/management/server/peer/peer_test.go +++ b/management/server/peer/peer_test.go @@ -2,6 +2,7 @@ package peer import ( "fmt" + "net/netip" "testing" ) @@ 
-29,3 +30,56 @@ func BenchmarkFQDN(b *testing.B) { } }) } + +func TestIsEqual(t *testing.T) { + meta1 := PeerSystemMeta{ + NetworkAddresses: []NetworkAddress{{ + NetIP: netip.MustParsePrefix("192.168.1.2/24"), + Mac: "2", + }, + { + NetIP: netip.MustParsePrefix("192.168.1.0/24"), + Mac: "1", + }, + }, + Files: []File{ + { + Path: "/etc/hosts1", + Exist: true, + ProcessIsRunning: true, + }, + { + Path: "/etc/hosts2", + Exist: false, + ProcessIsRunning: false, + }, + }, + } + meta2 := PeerSystemMeta{ + NetworkAddresses: []NetworkAddress{ + { + NetIP: netip.MustParsePrefix("192.168.1.0/24"), + Mac: "1", + }, + { + NetIP: netip.MustParsePrefix("192.168.1.2/24"), + Mac: "2", + }, + }, + Files: []File{ + { + Path: "/etc/hosts2", + Exist: false, + ProcessIsRunning: false, + }, + { + Path: "/etc/hosts1", + Exist: true, + ProcessIsRunning: true, + }, + }, + } + if !meta1.isEqual(meta2) { + t.Error("meta1 should be equal to meta2") + } +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0f1bb1888..c0ae4e178 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "io" "net" @@ -1090,7 +1092,6 @@ func Test_RegisterPeerByUser(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", @@ -1155,7 +1156,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "existingSetupKey", UserID: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ @@ -1175,7 +1175,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) - assert.Equal(t, peer.SetupKey, existingSetupKeyID) account, err := store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) @@ -1187,8 +1186,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed) - assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes) + + hashedKey := sha256.Sum256([]byte(existingSetupKeyID)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + assert.NotEqual(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed) + assert.Equal(t, 1, account.SetupKeys[encodedHashedKey].UsedTimes) } @@ -1221,7 +1223,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "existingSetupKey", UserID: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ @@ -1250,6 +1251,328 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) - assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) + + hashedKey := sha256.Sum256([]byte(faultyKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed.UTC()) + assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes) +} + +func 
TestPeerAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) + require.NoError(t, err) + + err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + require.NoError(t, err) + + // create users with auto groups + _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{ + { + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupA"}, + }, + { + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupB"}, + }, + { + Id: "regularUser3", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupC"}, + }, + }, true) + require.NoError(t, err) + + var peer4 *nbpeer.Peer + var peer5 *nbpeer.Peer + var peer6 *nbpeer.Peer + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Updating a peer that is not expired while peer expiration is enabled should not update account peers and not send peer update + t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Adding peer to unlinked group should not update account peers and not send peer update + t.Run("adding peer to unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting peer with unlinked group should not update account peers and not send peer update + t.Run("deleting peer with unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating peer label should update account peers and send peer update + t.Run("updating peer label", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + peer1.Name = "peer-1" + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1) + require.NoError(t, err) + +
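// The verification pattern used throughout these subtests: the goroutine
// above consumes the update channel via peerShouldReceiveUpdate (or its
// negative counterpart) and closes done; the select below bounds the wait
// at one second so a missed update fails the subtest quickly instead of
// hanging the suite.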
select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with policy should update account peers and send peer update + t.Run("adding peer to group linked with policy", func(t *testing.T) { + err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }, false) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to policy should update account peers and send peer update + t.Run("deleting peer with linked group to policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with route should update account peers and send peer update + t.Run("adding peer to group linked with route", func(t *testing.T) { + route := nbroute.Route{ + ID: "testingRoute1", + Network: netip.MustParsePrefix("100.65.250.202/32"), + NetID: "superNet", + NetworkType: nbroute.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupB"}, + } + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to route should update account peers and send peer update + t.Run("deleting peer with linked group to route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer5.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting 
for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with name server group should update account peers and send peer update + t.Run("adding peer to group linked with name server group", func(t *testing.T) { + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupC"}, + true, []string{}, true, userID, false, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to name server group should update account peers and send peer update + t.Run("deleting peer with linked group to route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer6.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) } diff --git a/management/server/policy.go b/management/server/policy.go index 63ac36cbf..3acef58b8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -7,9 +7,9 @@ import ( "strconv" "strings" + "github.com/netbirdio/netbird/management/proto" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -203,6 +203,17 @@ func (p *Policy) UpgradeAndFix() { } } +// ruleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) ruleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) + } + return groups +} + // FirewallRule is a rule of the firewall. 
type FirewallRule struct { // PeerIP of the peer @@ -331,12 +342,12 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if !user.IsAdminOrServiceUser() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewPoliciesError() } return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) @@ -349,12 +360,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if !user.IsAdminOrServiceUser() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if user.IsRegularUser() { + return status.NewUnauthorizedToViewPoliciesError() } groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) @@ -377,38 +388,38 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - action := activity.PolicyAdded - if isUpdate { - action = activity.PolicyUpdated - - if _, err = am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID); err != nil { - return err - } + updateAccountPeers, err := am.arePolicyChangesAffectPeers(ctx, policy, isUpdate) + if err != nil { + return err } err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return fmt.Errorf(errNetworkSerialIncrementFmt, err) + return fmt.Errorf("failed to increment network serial: %w", err) } - err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy) - if err != nil { + if err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy); err != nil { return fmt.Errorf("failed to save policy: %w", err) } - return nil }) if err != nil { return err } + action := activity.PolicyAdded + if isUpdate { + action = activity.PolicyUpdated + } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -421,7 +432,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) @@ -430,13 +441,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po } err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - err = transaction.IncrementNetworkSerial(ctx, 
LockingStrengthUpdate, accountID) - if err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf(errNetworkSerialIncrementFmt, err) } - err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) - if err != nil { + if err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID); err != nil { return fmt.Errorf("failed to delete policy: %w", err) } return nil @@ -456,34 +465,49 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return nil } -// ListPolicies from the store +// ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewPoliciesError() } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - result[i] = &proto.FirewallRule{ - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, +// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. +func (am *DefaultAccountManager) arePolicyChangesAffectPeers(ctx context.Context, policyToSave *Policy, isUpdate bool) (bool, error) { + if isUpdate { + existingPolicy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyToSave.AccountID, policyToSave.ID) + if err != nil { + return false, err } + + if !policyToSave.Enabled && !existingPolicy.Enabled { + return false, nil + } + + hasPeers, err := am.anyGroupHasPeers(ctx, policyToSave.AccountID, existingPolicy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return am.anyGroupHasPeers(ctx, policyToSave.AccountID, policyToSave.ruleGroups()) } - return result + + return am.anyGroupHasPeers(ctx, policyToSave.AccountID, policyToSave.ruleGroups()) } // getAllPeersFromGroups for given peer ID and list of groups @@ -597,3 +621,20 @@ func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { return validIDs } + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. 
+func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + result[i] = &proto.FirewallRule{ + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, + } + } + return result +} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index bf9a53d16..5b1411702 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -5,7 +5,9 @@ import ( "fmt" "net" "testing" + "time" + "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -824,3 +826,375 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { return 0 // a is equal to b } } + +func TestPolicyAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + { + ID: "groupD", + Name: "GroupD", + Peers: []string{peer1.ID, peer2.ID}, + }, + }) + assert.NoError(t, err) + + updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + }) + + // Saving policy with rule groups with no peers should not update account's peers and not send peer update + t.Run("saving policy with rule groups with no peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-rule-groups-no-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Saving policy with source group containing peers, but destination group without peers should + // update account's peers and send peer update + t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-has-peers-destination-none", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving policy with destination 
group containing peers, but source group without peers should + // update account's peers and send peer update + t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { + policy := Policy{ + ID: "policy-destination-has-peers-source-none", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: false, + Sources: []string{"groupC"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg2) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Disabling policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: false, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating disabled policy with destination and source groups containing peers should not update account's peers + // or send peer update + t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Description: "updated description", + Enabled: false, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Enabling policy 
with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged policy should trigger account peers update but not send peer update + t.Run("saving unchanged policy", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting policy should trigger account peers update and send peer update + t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { + policyID := "policy-source-destination-peers" + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + + }) + + // Deleting policy with destination group containing peers, but source group without peers should + // update account's peers and send peer update + t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { + policyID := "policy-destination-has-peers-source-none" + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg2) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting policy with no peers in groups should not update account's peers and not send peer update + t.Run("deleting policy with no peers in groups", func(t *testing.T) { + policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2 + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 
5aa12cd6d..d75b99ffa 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -12,22 +12,18 @@ import ( "golang.org/x/exp/maps" ) -const ( - errPostureAdminOnlyMsg = "only users with admin power are allowed to view posture checks" -) - func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, errPostureAdminOnlyMsg) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if !user.HasAdminPower() { + return nil, status.NewUnauthorizedToViewPostureChecksError() } return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) @@ -40,20 +36,24 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only admin users are allowed to update posture checks") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if !user.HasAdminPower() { + return status.NewUnauthorizedToViewPostureChecksError() } if err = am.validatePostureChecks(ctx, accountID, postureChecks); err != nil { return status.Errorf(status.InvalidArgument, err.Error()) } - action := activity.PostureCheckCreated + updateAccountPeers, err := am.arePostureCheckChangesAffectPeers(ctx, accountID, postureChecks.ID, isUpdate) + if err != nil { + return err + } + action := activity.PostureCheckCreated err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if isUpdate { action = activity.PostureCheckUpdated @@ -63,7 +63,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return fmt.Errorf(errNetworkSerialIncrementFmt, err) + return fmt.Errorf("failed to increment network serial: %w", err) } } @@ -78,11 +78,12 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if isUpdate { + if updateAccountPeers { account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return fmt.Errorf("failed to get account: %w", err) } + am.updateAccountPeers(ctx, account) } @@ -115,16 +116,12 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only admin users are allowed to delete posture checks") - } - if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } - if err = am.isPostureCheckLinkedToPolicy(ctx, postureChecksID, accountID); err != nil { - return err + if !user.HasAdminPower() { + return status.NewUnauthorizedToViewPostureChecksError() } postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) @@ -132,9 +129,13 @@ func (am *DefaultAccountManager) 
DeletePostureChecks(ctx context.Context, accoun return err } + if err = am.isPostureCheckLinkedToPolicy(ctx, postureChecksID, accountID); err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return fmt.Errorf(errNetworkSerialIncrementFmt, err) + return fmt.Errorf("failed to increment network serial: %w", err) } if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil { @@ -148,12 +149,6 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) - return nil } @@ -164,12 +159,12 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, err } - if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, errPostureAdminOnlyMsg) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if !user.HasAdminPower() { + return nil, status.NewUnauthorizedToViewPostureChecksError() } return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) @@ -250,3 +245,30 @@ func (am *DefaultAccountManager) isPeerInPolicySourceGroups(ctx context.Context, return false, nil } + +// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. 
+func (am *DefaultAccountManager) arePostureCheckChangesAffectPeers(ctx context.Context, accountID, postureCheckID string, exists bool) (bool, error) { + if !exists { + return false, nil + } + + policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureCheckID) { + hasPeers, err := am.anyGroupHasPeers(ctx, accountID, policy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + } + } + + return false, nil +} diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index b618fb20b..77ab1e8a7 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -3,7 +3,10 @@ package server import ( "context" "testing" + "time" + "github.com/netbirdio/netbird/management/server/group" + "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/posture" @@ -118,3 +121,458 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { return am.Store.GetAccount(context.Background(), account.Id) } + +func TestPostureCheckAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + postureCheck := posture.Checks{ + ID: "postureCheck", + Name: "postureCheck", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.28.0", + }, + }, + } + + // Saving unused posture check should not update account peers and not send peer update + t.Run("saving unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating unused posture check should not update account peers and not send peer update + t.Run("updating unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + policy := Policy{ + ID: "policyA", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: 
[]string{postureCheck.ID}, + } + + // Linking posture check to policy should trigger account peers update and send peer update + t.Run("linking posture check to policy with peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating linked posture checks should update account peers and send peer update + t.Run("updating linked to posture check with peers", func(t *testing.T) { + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged posture check should not trigger account peers update and not send peer update + // since there is no change in the network map + t.Run("saving unchanged posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing posture check from policy should trigger account peers update and send peer update + t.Run("removing posture check from policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policy.SourcePostureChecks = []string{} + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting unused posture check should not trigger account peers update and not send peer update + t.Run("deleting unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update + t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action:
PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating linked posture check to policy where destination has peers but source does not + // should trigger account peers update and send peer update + t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { + updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + }) + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating linked posture check to policy where source has peers but destination does not, + // should not trigger account peers update or send peer update + t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating linked client posture check to policy where source has peers but destination does not, + // should trigger account peers update and send peer update + t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, 
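// Unlike the NB version check exercised in the previous subtest, a process
// check is evaluated on the client itself, so it has to be delivered to the
// peers in the source group (groupA, which has peers) even though the
// destination group is empty; that is why this subtest expects an update.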
+ Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + }, + }, + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} + +func TestArePostureCheckChangesAffectingPeers(t *testing.T) { + account := &Account{ + Policies: []*Policy{ + { + ID: "policyA", + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{"checkA"}, + }, + }, + Groups: map[string]*group.Group{ + "groupA": { + ID: "groupA", + Peers: []string{"peer1"}, + }, + "groupB": { + ID: "groupB", + Peers: []string{}, + }, + }, + PostureChecks: []*posture.Checks{ + { + ID: "checkA", + }, + { + ID: "checkB", + }, + }, + } + + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + assert.False(t, result) + }) + + t.Run("posture check does not exist", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + assert.False(t, result) + }) + + t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"groupB"} + account.Policies[0].Rules[0].Destinations = []string{"groupA"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"groupA"} + account.Policies[0].Rules[0].Destinations = []string{"groupB"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} + account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.False(t, result) + }) + + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + account.Groups["groupA"].Peers = []string{} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.False(t, result) + }) +} diff --git a/management/server/route.go b/management/server/route.go index ce6ea79e2..9b5229092 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -52,12 +52,12 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, err } - if 
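Every subtest above, and the route and setup-key tests below, repeats the same choreography: spawn a goroutine that asserts on the peer's update channel, apply the mutation, then wait on a done channel with a one-second timeout. A sketch of a helper that could consolidate the pattern; peerShouldReceiveUpdate, peerShouldNotReceiveUpdate, and the UpdateMessage channel type are assumed from this test package, so this is an illustration rather than part of the change:

func expectPeerUpdateOutcome(t *testing.T, updMsg chan *UpdateMessage, check func(*testing.T, chan *UpdateMessage), mutate func()) {
	t.Helper()

	done := make(chan struct{})
	go func() {
		check(t, updMsg) // peerShouldReceiveUpdate or peerShouldNotReceiveUpdate
		close(done)
	}()

	mutate() // e.g. SavePostureChecks, SaveRoute, CreateSetupKey

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Error("timeout waiting for peer update assertion")
	}
}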
!user.IsAdminOrServiceUser() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewRoutesError() } return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID)) @@ -181,12 +181,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) - } - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err + return nil, status.NewUserNotPartOfAccountError() } // Do not allow non-Linux peers @@ -274,6 +269,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri newRoute.KeepRoute = keepRoute newRoute.AccessControlGroups = accessControlGroupIDs + updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, &newRoute) + if err != nil { + return nil, err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf(errNetworkSerialIncrementFmt, err) @@ -292,11 +292,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) - account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf(errGetAccountFmt, err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return &newRoute, nil } @@ -309,7 +311,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } if routeToSave == nil { @@ -324,7 +326,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - _, err = am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeToSave.ID)) + oldRoute, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeToSave.ID)) if err != nil { return err } @@ -386,6 +388,16 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + oldRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, oldRoute) + if err != nil { + return err + } + + newRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, routeToSave) + if err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf(errNetworkSerialIncrementFmt, err) @@ -404,11 +416,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, 
routeToSave.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) + if oldRouteAffectsPeers || newRouteAffectsPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -421,7 +435,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return status.NewUserNotPartOfAccountError() } route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID)) @@ -429,6 +443,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return err } + updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, route) + if err != nil { + return err + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf(errNetworkSerialIncrementFmt, err) @@ -442,11 +461,13 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) + if updateAccountPeers { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } + am.updateAccountPeers(ctx, account) } - am.updateAccountPeers(ctx, account) return nil } @@ -458,12 +479,12 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, err } - if !user.IsAdminOrServiceUser() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewRoutesError() } return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) @@ -741,3 +762,22 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { } return &portInfo } + +// areRouteChangesAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers. 
+func (am *DefaultAccountManager) areRouteChangesAffectPeers(ctx context.Context, route *route.Route) (bool, error) { + if route.Peer != "" { + return true, nil + } + + hasPeers, err := am.anyGroupHasPeers(ctx, route.AccountID, route.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return am.anyGroupHasPeers(ctx, route.AccountID, route.PeerGroups) +} diff --git a/management/server/route_test.go b/management/server/route_test.go index 17eac951d..8bbff8d38 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "testing" + "time" "github.com/rs/xid" "github.com/stretchr/testify/assert" @@ -1777,3 +1778,281 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }) } + +func TestRouteAccountPeersUpdate(t *testing.T) { + manager, err := createRouterManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestRouteAccount(t, manager) + require.NoError(t, err, "failed to init testing account") + + err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + }) + + // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update + t.Run("creating route no routing peer and no peers in groups", func(t *testing.T) { + route := route.Route{ + ID: "testingRoute1", + Network: netip.MustParsePrefix("100.65.250.202/32"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupA"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupA"}, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + + }) + + // Creating a route with no routing peer and having peers in groups should update account peers and send peer update + t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) { + route := route.Route{ + ID: "testingRoute2", + Network: netip.MustParsePrefix("192.0.2.0/32"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroup3}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup3}, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + 
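CreateRoute, SaveRoute, and DeleteRoute above all consult areRouteChangesAffectPeers before fanning out: the account is only reloaded and updateAccountPeers only invoked when the route has a routing peer, or when its distribution groups (Groups) or peer groups (PeerGroups) contain at least one peer. The anyGroupHasPeers helper it relies on is not part of this diff; a plausible sketch under that caveat, assuming the GetGroupByID accessor declared on the Store interface:

// anyGroupHasPeers (hypothetical sketch, not from this diff) reports whether
// any of the given groups contains at least one peer; the real helper lives
// elsewhere in the account manager.
func (am *DefaultAccountManager) anyGroupHasPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
	for _, groupID := range groupIDs {
		group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
		if err != nil {
			return false, err
		}
		if len(group.Peers) > 0 {
			return true, nil
		}
	}
	return false, nil
}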
select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + + }) + + baseRoute := route.Route{ + ID: "testingRoute3", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + + // Creating route should update account peers and send peer update + t.Run("creating route with a routing peer", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newRoute, err := manager.CreateRoute( + context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, + baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, + baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, + ) + require.NoError(t, err) + baseRoute = *newRoute + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating the route should update account peers and send peer update when there is peers in group + t.Run("updating route", func(t *testing.T) { + baseRoute.Groups = []string{routeGroup1, routeGroup2} + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating unchanged route should update account peers and not send peer update + t.Run("updating unchanged route", func(t *testing.T) { + baseRoute.Groups = []string{routeGroup1, routeGroup2} + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting the route should update account peers and send peer update + t.Run("deleting route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to route peer groups that do not have any peers should update account peers and send peer update + t.Run("adding peer to route peer groups that do not have any peers", func(t *testing.T) { + newRoute := route.Route{ + Network: netip.MustParsePrefix("192.168.12.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + 
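// Note on the "updating unchanged route" case above: SaveRoute still calls
// updateAccountPeers, since the route's groups do contain peers, but the
// regenerated network map is identical to the last one sent, so the
// deduplication added to PeersUpdateManager.SendUpdate (see updatechannel.go
// below) drops the message before it ever reaches the peer's channel.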
done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to route groups that do not have any peers should update account peers and send peer update + t.Run("adding peer to route groups that do not have any peers", func(t *testing.T) { + newRoute := route.Route{ + Network: netip.MustParsePrefix("192.168.13.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupC"}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 956a97f1c..a3d39a217 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "hash/fnv" "strconv" "strings" @@ -74,6 +76,7 @@ type SetupKey struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` Key string + KeySecret string Name string Type SetupKeyType CreatedAt time.Time @@ -105,6 +108,7 @@ func (key *SetupKey) Copy() *SetupKey { Id: key.Id, AccountID: key.AccountID, Key: key.Key, + KeySecret: key.KeySecret, Name: key.Name, Type: key.Type, CreatedAt: key.CreatedAt, @@ -121,19 +125,17 @@ func (key *SetupKey) Copy() *SetupKey { // EventMeta returns activity event meta related to the setup key func (key *SetupKey) EventMeta() map[string]any { - return map[string]any{"name": key.Name, "type": key.Type, "key": key.HiddenCopy(1).Key} + return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} } -// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix. +// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. 
// E.g., "831F6*******************************" -func (key *SetupKey) HiddenCopy(length int) *SetupKey { - k := key.Copy() - prefix := k.Key[0:5] - if length > utf8.RuneCountInString(key.Key) { - length = utf8.RuneCountInString(key.Key) - len(prefix) +func hiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) } - k.Key = prefix + strings.Repeat("*", length) - return k + return prefix + strings.Repeat("*", length) } // IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now @@ -156,6 +158,9 @@ func (key *SetupKey) IsRevoked() bool { // IsExpired if key was expired func (key *SetupKey) IsExpired() bool { + if key.ExpiresAt.IsZero() { + return false + } return time.Now().After(key.ExpiresAt) } @@ -170,30 +175,40 @@ func (key *SetupKey) IsOverUsed() bool { // GenerateSetupKey generates a new setup key func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int, ephemeral bool) *SetupKey { + usageLimit int, ephemeral bool) (*SetupKey, string) { key := strings.ToUpper(uuid.New().String()) limit := usageLimit if t == SetupKeyOneOff { limit = 1 } + + expiresAt := time.Time{} + if validFor != 0 { + expiresAt = time.Now().UTC().Add(validFor) + } + + hashedKey := sha256.Sum256([]byte(key)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + return &SetupKey{ Id: strconv.Itoa(int(Hash(key))), - Key: key, + Key: encodedHashedKey, + KeySecret: hiddenKey(key, 4), Name: name, Type: t, CreatedAt: time.Now().UTC(), - ExpiresAt: time.Now().UTC().Add(validFor), + ExpiresAt: expiresAt, UpdatedAt: time.Now().UTC(), Revoked: false, UsedTimes: 0, AutoGroups: autoGroups, UsageLimit: limit, Ephemeral: ephemeral, - } + }, key } // GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration -func GenerateDefaultSetupKey() *SetupKey { +func GenerateDefaultSetupKey() (*SetupKey, string) { return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, SetupKeyUnlimitedUsage, false) } @@ -217,12 +232,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) - } - - keyDuration := DefaultSetupKeyDuration - if expiresIn != 0 { - keyDuration = expiresIn + return nil, status.NewUserNotPartOfAccountError() } groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) @@ -234,7 +244,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, err } - setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) + setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey); err != nil { @@ -257,6 +267,9 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } } + // for the creation return the plain key to the caller + setupKey.Key = plainKey + return setupKey, nil } @@ -275,7 +288,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return nil, status.NewUserNotPartOfAccountError() } 
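With this change the plain setup key exists only at creation time: GenerateSetupKey returns it alongside a SetupKey whose Key field stores base64(sha256(plain)) and whose KeySecret stores a masked preview. A minimal, standard-library-only sketch of both derivations (the masking is simplified here; hiddenKey also clamps the asterisk count to the key length):

package main

import (
	"crypto/sha256"
	b64 "encoding/base64"
	"fmt"
	"strings"
)

func main() {
	plain := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"

	// What SetupKey.Key now holds, and what store lookups expect.
	sum := sha256.Sum256([]byte(plain))
	fmt.Println(b64.StdEncoding.EncodeToString(sum[:]))

	// What SetupKey.KeySecret holds, mirroring hiddenKey(plain, 4): a
	// five-character prefix followed by asterisks, e.g. "A2C8E****".
	fmt.Println(plain[0:5] + strings.Repeat("*", 4))
}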
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) @@ -348,12 +361,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - if !user.IsAdminOrServiceUser() { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewSetupKeysError() } setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) @@ -361,18 +374,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - keys := make([]*SetupKey, 0, len(setupKeys)) - for _, key := range setupKeys { - var k *SetupKey - if !user.IsAdminOrServiceUser() { - k = key.HiddenCopy(999) - } else { - k = key.Copy() - } - keys = append(keys, k) - } - - return keys, nil + return setupKeys, nil } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. @@ -382,15 +384,15 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, err } - if !user.IsAdminOrServiceUser() { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") - } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + return nil, status.NewUserNotPartOfAccountError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + if user.IsRegularUser() { + return nil, status.NewUnauthorizedToViewSetupKeysError() + } + + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if err != nil { return nil, err } @@ -400,11 +402,37 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use setupKey.UpdatedAt = setupKey.CreatedAt } - if !user.IsAdminOrServiceUser() { - setupKey = setupKey.HiddenCopy(999) + return setupKey, nil +} + +// DeleteSetupKey removes the setup key from the account +func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err } - return setupKey, nil + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewUnauthorizedToViewSetupKeysError() + } + + deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + if err != nil { + return err + } + + err = am.Store.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) + + return nil } func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index aa5075b02..2ed8aef95 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -2,13 +2,17 @@ package server import ( "context" + "crypto/sha256" + "encoding/base64" "fmt" "strconv" + "strings" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" 
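The expiry semantics also changed here: CreateSetupKey no longer substitutes DefaultSetupKeyDuration when expiresIn is zero. GenerateSetupKey leaves ExpiresAt as the zero time in that case, and IsExpired treats a zero ExpiresAt as never expiring, so a zero duration now yields a non-expiring key rather than a 30-day one. A hypothetical call site, assuming an account manager am and valid IDs:

// An expiresIn of 0 now means the key never expires.
key, err := am.CreateSetupKey(ctx, accountID, "ci-key", SetupKeyReusable, 0, nil, SetupKeyUnlimitedUsage, userID, false)
if err != nil {
	return err
}
// key.Key carries the plain value exactly once; reads through GetSetupKey or
// ListSetupKeys afterwards return the hashed Key and the masked KeySecret.
fmt.Println(key.Key)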
"github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" @@ -65,7 +69,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, - key.Id, time.Now().UTC(), autoGroups) + key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) @@ -182,7 +186,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), - tCase.expectedUpdatedAt, tCase.expectedGroups) + tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) @@ -238,10 +242,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups []string - key := GenerateDefaultSetupKey() + key, plainKey := GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) + expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) } @@ -255,41 +259,41 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) + expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) } func TestSetupKey_IsValid(t *testing.T) { - validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) 
overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // reusable - reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got invalid %v", reusableKey) } @@ -298,7 +302,7 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string, - expectedUpdatedAt time.Time, expectedAutoGroups []string) { + expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) { t.Helper() if key.Name != expectedName { t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name) @@ -328,13 +332,23 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt) } - _, err := uuid.Parse(key.Key) - if err != nil { - t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) + if expectHashedKey { + if !isValidBase64SHA256(key.Key) { + t.Errorf("expected key to be hashed, got %v", key.Key) + } + } else { + _, err := uuid.Parse(key.Key) + if err != nil { + t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) + } } - if key.Id != strconv.Itoa(int(Hash(key.Key))) { - t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id) + if !strings.HasSuffix(key.KeySecret, "****") { + t.Errorf("expected key secret to be secure, got %v", key.KeySecret) + } + + if key.Id != expectedID { + t.Errorf("expected key Id %v, got %v", expectedID, key.Id) } if len(key.AutoGroups) != len(expectedAutoGroups) { @@ -343,12 +357,95 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal") } +func isValidBase64SHA256(encodedKey string) bool { + decoded, err := base64.StdEncoding.DecodeString(encodedKey) + if err != nil { + return false + } + + if len(decoded) != sha256.Size { + return false + } + + return true +} + func TestSetupKey_Copy(t *testing.T) { - key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, - key.UpdatedAt, key.AutoGroups) + key.UpdatedAt, key.AutoGroups, true) } + +func TestSetupKeyAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"group"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = manager.SavePolicy(context.Background(), account.Id,
userID, &policy, false) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + var setupKey *SetupKey + + // Creating setup key should not update account peers and not send peer update + t.Run("creating setup key", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Saving setup key should not update account peers and not send peer update + t.Run("saving setup key", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveSetupKey(context.Background(), account.Id, setupKey, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 5b2d61b59..c68c182df 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -430,16 +430,6 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, return nil } -// DeleteHashedPAT2TokenIDIndex is noop in SqlStore -func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { - return nil -} - -// DeleteTokenID2UserIDIndex is noop in SqlStore -func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { - return nil -} - func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if err != nil { @@ -469,7 +459,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { var key SetupKey - result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey)) + result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -741,7 +731,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID) + result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -993,7 +983,7 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { var setupKey SetupKey result := 
s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&setupKey, keyQueryCondition, strings.ToUpper(key)) + First(&setupKey, keyQueryCondition, key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") @@ -1543,6 +1533,21 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt return nil } +func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete setup key from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete setup key from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "setup key not found") + } + + return nil +} + // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { var nsGroups []*nbdns.NameServerGroup @@ -1597,7 +1602,7 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki } // GetPATByID retrieves a personal access token by its ID and user ID. -func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) { +func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) { var pat PersonalAccessToken result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). First(&pat, "id = ? AND user_id = ?", patID, userID) @@ -1612,6 +1617,18 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, return &pat, nil } +// GetUserPATs retrieves personal access tokens for a user. +func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) { + var pats []*PersonalAccessToken + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get user PAT's from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get user PAT's from store") + } + + return pats, nil +} + // SavePAT saves a personal access token to the database. 
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 09be9ab85..a8e6576ed 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "math/rand" "net" @@ -71,7 +73,7 @@ func runLargeTest(t *testing.T, store Store) { if err != nil { t.Fatal(err) } - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { @@ -81,7 +83,6 @@ func runLargeTest(t *testing.T, store Store) { peer := &nbpeer.Peer{ ID: peerID, Key: peerID, - SetupKey: "", IP: netIP, Name: peerID, DNSLabel: peerID, @@ -133,7 +134,7 @@ func runLargeTest(t *testing.T, store Store) { } account.NameServerGroups[nameserver.ID] = nameserver - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey } @@ -215,30 +216,28 @@ func TestSqlite_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey = GenerateDefaultSetupKey() + setupKey, _ = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - SetupKey: "peerkeysetupkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account2) @@ -297,15 +296,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user @@ -394,13 +392,12 @@ func TestSqlite_SavePeer(t *testing.T) { // save status of 
non-existing peer peer := &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) @@ -453,13 +450,12 @@ func TestSqlite_SavePeerStatus(t *testing.T) { // save new status of existing peer account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -720,15 +716,14 @@ func newSqliteStore(t *testing.T) *SqlStore { func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ - Key: "peerkey" + str, - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey" + str, + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } return store.SaveAccount(context.Background(), account) @@ -760,30 +755,28 @@ func TestPostgresql_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey = GenerateDefaultSetupKey() + setupKey, _ = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - SetupKey: "peerkeysetupkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = 
store.SaveAccount(context.Background(), account2) @@ -842,15 +835,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user @@ -921,13 +913,12 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { // save new status of existing peer account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -1118,12 +1109,17 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) - assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key) + assert.Equal(t, encodedHashedKey, setupKey.Key) + assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "Default key", setupKey.Name) } @@ -1138,24 +1134,28 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 0, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 1, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) 
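// Note: the store no longer upper-cases lookup values (see the removed
// strings.ToUpper calls in GetAccountBySetupKey, GetAccountIDBySetupKey and
// GetSetupKeyBySecret above). The key column now stores base64(sha256(plain)),
// which must match byte-for-byte, so callers hash first, as this test does:
//
//	hashedKey := sha256.Sum256([]byte(plainKey))
//	encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
//	setupKey, err := store.GetSetupKeyBySecret(ctx, LockingStrengthShare, encodedHashedKey)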
require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } @@ -1264,3 +1264,32 @@ func TestSqlite_GetGroupByName(t *testing.T) { require.NoError(t, err) require.Equal(t, "All", group.Name) } + +func Test_DeleteSetupKeySuccessfully(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + + err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) + require.NoError(t, err) + + _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) + require.Error(t, err) +} + +func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + nonExistingKeyID := "non-existing-key-id" + + err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) + require.Error(t, err) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 29d185216..7a4ec3f67 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -101,16 +101,62 @@ func NewPeerLoginExpiredError() error { return Errorf(PermissionDenied, "peer login has expired, please log in once more") } -// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError(err error) error { - return Errorf(NotFound, "setup key not found: %s", err) -} - func NewGetAccountFromStoreError(err error) error { return Errorf(Internal, "issue getting account from store: %s", err) } +func NewUnauthorizedToViewAccountSettingError() error { + return Errorf(PermissionDenied, "only users with admin power can view account settings") +} + +// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account +func NewUserNotPartOfAccountError() error { + return Errorf(PermissionDenied, "user is not part of this account") +} + // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store func NewGetUserFromStoreError() error { return Errorf(Internal, "issue getting user from store") } + +// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key +func NewInvalidKeyIDError() error { + return Errorf(InvalidArgument, "invalid key ID") +} + +// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key +func NewSetupKeyNotFoundError(err error) error { + return Errorf(NotFound, "setup key not found: %s", err) +} + +// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key +func NewUnauthorizedToViewSetupKeysError() error { + return Errorf(PermissionDenied, "only users with admin power can view setup keys") +} + +func NewUnauthorizedToViewGroupsError() error { + 
return Errorf(PermissionDenied, "only users with admin power can view groups") +} +func NewUnauthorizedToViewPATsError() error { + return Errorf(PermissionDenied, "only users with admin power can view personal access tokens") +} + +func NewUnauthorizedToViewPoliciesError() error { + return Errorf(PermissionDenied, "only users with admin power can view policies") +} + +func NewUnauthorizedToViewPostureChecksError() error { + return Errorf(PermissionDenied, "only users with admin power can view posture checks") +} + +func NewUnauthorizedToViewDNSSettingsError() error { + return Errorf(PermissionDenied, "only users with admin power can view dns settings") +} + +func NewUnauthorizedToViewNSGroupsError() error { + return Errorf(PermissionDenied, "only users with admin power can view name server groups") +} + +func NewUnauthorizedToViewRoutesError() error { + return Errorf(PermissionDenied, "only users with admin power can view network routes") +} diff --git a/management/server/store.go b/management/server/store.go index e4b948be6..fda499e9d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -69,8 +69,6 @@ type Store interface { SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) - DeleteHashedPAT2TokenIDIndex(hashedToken string) error - DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) @@ -111,6 +109,7 @@ type Store interface { GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) @@ -123,6 +122,7 @@ type Store interface { DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) + GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error @@ -263,6 +263,9 @@ func getMigrations(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, + func(db *gorm.DB) error { + return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db) + }, } } diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 32a59128b..168973cad 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -26,8 +26,11 @@ CREATE INDEX 
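The newly registered MigrateSetupKeyToHashedSetupKey migration is referenced above but its body is not part of this diff; presumably it rewrites previously stored plain keys so that existing rows match the new hashed-lookup scheme. A sketch of the idea only, under that assumption, not the actual migration code:

// Hypothetical core of the migration: hash each stored plain key and keep a
// masked preview, mirroring what GenerateSetupKey now does for new keys.
for _, key := range existingKeys {
	plain := key.Key
	sum := sha256.Sum256([]byte(plain))
	key.Key = b64.StdEncoding.EncodeToString(sum[:])
	key.KeySecret = hiddenKey(plain, 4)
	// ...persist the updated row inside the migration transaction
}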
`idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); INSERT INTO installations VALUES(1,''); +INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); +INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 0188cef52..6fb96c971 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,9 +2,13 @@ package server import ( "context" + "fmt" + "runtime/debug" "sync" "time" + "github.com/netbirdio/netbird/management/server/differs" + "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" @@ -14,14 +18,17 @@ import ( const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + NetworkMap *NetworkMap } type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels map[string]chan *UpdateMessage + // peerNetworkMaps is the UpdateMessage indexed by Peer.ID. 
+ peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels - channelsMux *sync.Mutex + channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } @@ -29,9 +36,10 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - channelsMux: &sync.Mutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + peerUpdateMessage: make(map[string]*UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool + // skip sending sync update to the peer if there is no change in update message, + // it will not check on turn credential refresh as we do not send network map or client posture checks + if update.NetworkMap != nil { + updated := p.handlePeerMessageUpdate(ctx, peerID, update) + if !updated { + return + } + } + p.channelsMux.Lock() + defer func() { p.channelsMux.Unlock() if p.metrics != nil { @@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() + if update.NetworkMap != nil { + lastSentUpdate := p.peerUpdateMessage[peerID] + if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() { + log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", + peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) + return + } + p.peerUpdateMessage[peerID] = update + } + if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -170,3 +200,72 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. +func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { + p.channelsMux.RLock() + lastSentUpdate := p.peerUpdateMessage[peerID] + p.channelsMux.RUnlock() + + if lastSentUpdate != nil { + updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update) + if err != nil { + log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) + return false + } + if !updated { + log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) + return false + } + } + + return true +} + +// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. 
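isNewPeerUpdateMessage below treats an empty changelog from the r3labs/diff library as "nothing new worth sending". A minimal, runnable illustration of that semantics on a stand-in struct (the patch additionally registers custom differs for the netip types, presumably because those values do not compare well field-by-field; that part is omitted here, and the network struct is not the real NetworkMap):

package main

import (
	"fmt"

	"github.com/r3labs/diff/v3"
)

type network struct {
	Serial uint64
	Peers  []string
}

func main() {
	last := network{Serial: 1, Peers: []string{"peer1"}}
	curr := network{Serial: 2, Peers: []string{"peer1"}}

	// Serial differs, so the changelog has one entry.
	changelog, err := diff.Diff(last, curr)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(changelog)) // 1

	// With identical values the changelog is empty: this is the condition
	// under which SendUpdate can skip waking the peer at all.
	curr.Serial = last.Serial
	changelog, _ = diff.Diff(last, curr)
	fmt.Println(len(changelog) == 0) // true
}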
+func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) { + defer func() { + if r := recover(); r != nil { + log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack()) + isNew, err = true, nil + } + }() + + if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() { + return false, nil + } + + differ, err := diff.NewDiffer( + diff.CustomValueDiffers(&differs.NetIPAddr{}), + diff.CustomValueDiffers(&differs.NetIPPrefix{}), + ) + if err != nil { + return false, fmt.Errorf("failed to create differ: %v", err) + } + + lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks) + currFiles := getChecksFiles(currUpdateToSend.Update.Checks) + + changelog, err := differ.Diff(lastSentFiles, currFiles) + if err != nil { + return false, fmt.Errorf("failed to diff checks: %v", err) + } + if len(changelog) > 0 { + return true, nil + } + + changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) + if err != nil { + return false, fmt.Errorf("failed to diff network map: %v", err) + } + return len(changelog) > 0, nil +} + +// getChecksFiles returns a list of files from the given checks. +func getChecksFiles(checks []*proto.Checks) []string { + files := make([]string, 0, len(checks)) + for _, check := range checks { + files = append(files, check.GetFiles()...) + } + return files +} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 69f5b895c..52b715e95 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -2,10 +2,19 @@ package server import ( "context" + "net" + "net/netip" "testing" "time" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + nbroute "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/util" + "github.com/stretchr/testify/assert" ) // var peersUpdater *PeersUpdateManager @@ -77,3 +86,470 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } + +func TestHandlePeerMessageUpdate(t *testing.T) { + tests := []struct { + name string + peerID string + existingUpdate *UpdateMessage + newUpdate *UpdateMessage + expectedResult bool + }{ + { + name: "update message with turn credentials update", + peerID: "peer", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{}, + }, + }, + expectedResult: true, + }, + { + name: "update message for peer without existing update", + peerID: "peer1", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with no changes in update", + peerID: "peer2", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + { + name: "update message with changes in checks", + peerID: "peer3", + existingUpdate: &UpdateMessage{ + Update: 
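The deferred recover above is a fail-open guard: if the struct comparison ever panics, the update is reported as new rather than silently dropped, so a peer is never starved of a legitimate network map by a comparison bug. The shape of the pattern in isolation (failOpen is illustrative, not part of the patch; note that logrus's Panicf panics again after logging, so a non-panicking log call is what lets the fail-open assignment actually take effect):

// failOpen mirrors the recover in isNewPeerUpdateMessage: when in doubt
// (compare panicked), claim there is something new so the update is sent.
func failOpen(compare func() bool) (isNew bool) {
	defer func() {
		if r := recover(); r != nil {
			log.Errorf("comparing peer update messages: %v", r)
			isNew = true // better a redundant update than a missed one
		}
	}()
	return compare()
}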
&proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + Checks: []*proto.Checks{ + { + Files: []string{"/usr/bin/netbird"}, + }, + }, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with lower serial number", + peerID: "peer4", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPeersUpdateManager(nil) + ctx := context.Background() + + if tt.existingUpdate != nil { + p.peerUpdateMessage[tt.peerID] = tt.existingUpdate + } + + result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestIsNewPeerUpdateMessage(t *testing.T) { + t.Run("Unchanged value", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + }) + + t.Run("Unchanged value with serial incremented", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + }) + + t.Run("Updating routes network", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + + }) + + t.Run("Updating routes groups", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating network map peers", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newPeer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.4"), + SSHEnabled: true, + Key: "peer4-key", + DNSLabel: "peer4", + SSHKey: "peer4-ssh-key", + } + newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating process check", func(t *testing.T) { + newUpdateMessage1 := 
createMockUpdateMessage(t) + + newUpdateMessage2 := createMockUpdateMessage(t) + newUpdateMessage2.Update.NetworkMap.Serial++ + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + + newUpdateMessage3 := createMockUpdateMessage(t) + newUpdateMessage3.Update.Checks = []*proto.Checks{} + newUpdateMessage3.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3) + assert.NoError(t, err) + assert.True(t, message) + + newUpdateMessage4 := createMockUpdateMessage(t) + check := &posture.Checks{ + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/local/netbird", + MacPath: "/usr/bin/netbird", + }, + }, + }, + }, + } + newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)} + newUpdateMessage4.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4) + assert.NoError(t, err) + assert.True(t, message) + + newUpdateMessage5 := createMockUpdateMessage(t) + check = &posture.Checks{ + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", + MacPath: "/usr/local/netbird", + }, + }, + }, + }, + } + newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)} + newUpdateMessage5.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating DNS configuration", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newDomain := "newexample.com" + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append( + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains, + newDomain, + ) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating peer IP", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating firewall rule", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Add new firewall rule", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newRule := &FirewallRule{ + PeerIP: "192.168.1.3", + Direction: firewallRuleDirectionOUT, + Action: string(PolicyTrafficActionDrop), + Protocol: string(PolicyRuleProtocolUDP), + Port: "53", + } + 
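Every subtest in this block follows the same recipe: build two identical mock messages, mutate one field of the network map in the second, bump its serial, and assert the change is detected. Factored into a helper, the recipe reads as follows (a refactor sketch, not part of the patch; it assumes the package's createMockUpdateMessage and isNewPeerUpdateMessage as defined above):

// expectChangeDetected asserts that mutating one aspect of the network map
// makes isNewPeerUpdateMessage report the update as new.
func expectChangeDetected(t *testing.T, mutate func(m *UpdateMessage)) {
	t.Helper()
	base := createMockUpdateMessage(t)
	changed := createMockUpdateMessage(t)
	mutate(changed)
	changed.Update.NetworkMap.Serial++ // a real update always carries a newer serial

	isNew, err := isNewPeerUpdateMessage(context.Background(), base, changed)
	assert.NoError(t, err)
	assert.True(t, isNew)
}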
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Removing nameserver", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating name server IP", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating custom DNS zone", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + +} + +func createMockUpdateMessage(t *testing.T) *UpdateMessage { + t.Helper() + + _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + if err != nil { + t.Fatal(err) + } + domainList, err := domain.FromStringList([]string{"example.com"}) + if err != nil { + t.Fatal(err) + } + + config := &Config{ + Signal: &Host{ + Proto: "https", + URI: "signal.uri", + Username: "", + Password: "", + }, + Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, + TURNConfig: &TURNConfig{ + Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, + }, + } + peer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.1"), + SSHEnabled: true, + Key: "peer-key", + DNSLabel: "peer1", + SSHKey: "peer1-ssh-key", + } + + secretManager := NewTimeBasedAuthSecretsManager( + NewPeersUpdateManager(nil), + &TURNConfig{ + TimeBasedCredentials: false, + CredentialsTTL: util.Duration{ + Duration: defaultDuration, + }, + Secret: "secret", + Turns: []*Host{TurnTestHost}, + }, + &Relay{ + Addresses: []string{"localhost:0"}, + CredentialsTTL: util.Duration{Duration: time.Hour}, + Secret: "secret", + }, + ) + + networkMap := &NetworkMap{ + Network: &Network{Net: *ipNet, Serial: 1000}, + Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, + OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, + Routes: []*nbroute.Route{ + { + ID: "route1", + Network: netip.MustParsePrefix("10.0.0.0/24"), + KeepRoute: true, + NetID: "route1", + Peer: "peer1", + NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + { + ID: "route2", + Domains: domainList, + KeepRoute: true, + NetID: "route2", + Peer: "peer1", + 
NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + }, + DNSConfig: nbdns.Config{ + ServiceEnable: true, + NameServerGroups: []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + { + ID: "ns1", + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Groups: []string{"group1"}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + }, + CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, + }, + FirewallRules: []*FirewallRule{ + {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + }, + } + dnsName := "example.com" + checks := []*posture.Checks{ + { + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", + MacPath: "/usr/bin/netbird", + }, + }, + }, + }, + }, + } + dnsCache := &DNSConfigCache{} + + turnToken, err := secretManager.GenerateTurnToken() + if err != nil { + t.Fatal(err) + } + + relayToken, err := secretManager.GenerateRelayToken() + if err != nil { + t.Fatal(err) + } + + return &UpdateMessage{ + Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), + NetworkMap: networkMap, + } +} diff --git a/management/server/user.go b/management/server/user.go index 4c43c63fe..ac42b600b 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" @@ -31,8 +32,6 @@ const ( UserIssuedAPI = "api" UserIssuedIntegration = "integration" - - errUserNotPartOfAccountMsg = "user is not part of this account" ) // StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown @@ -104,6 +103,11 @@ func (u *User) IsAdminOrServiceUser() bool { return u.HasAdminPower() || u.IsServiceUser } +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + // ToUserInfo converts a User object to a UserInfo object. 
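IsRegularUser above is the predicate the reworked permission checks lean on: a user with neither admin power nor service-user status. Assuming HasAdminPower holds for admin and owner roles, the three kinds of user partition like this (a test-file sketch living next to the type; it needs only fmt):

func ExampleUser_IsRegularUser() {
	admin := &User{Role: UserRoleAdmin}                   // admin power -> not regular
	svc := &User{Role: UserRoleUser, IsServiceUser: true} // service user -> not regular
	plain := &User{Role: UserRoleUser}                    // neither -> regular

	fmt.Println(admin.IsRegularUser(), svc.IsRegularUser(), plain.IsRegularUser())
	// Output: false false true
}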
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups @@ -475,7 +479,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init } func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { return err } @@ -487,18 +491,30 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error { - peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, account.Id, targetUserID) +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) { + peers, err := account.FindUserPeers(targetUserID) if err != nil { - return err + return false, status.Errorf(status.Internal, "failed to find user peers") } - return am.deletePeers(ctx, account.Id, initiatorUserID, peers) + hadPeers := len(peers) > 0 + if !hadPeers { + return false, nil + } + + peerIDs := make([]string, 0, len(peers)) + for _, peer := range peers { + peerIDs = append(peerIDs, peer.ID) + } + + return hadPeers, am.deletePeers(ctx, account.Id, initiatorUserID, peers) } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. 
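deleteUserPeers now reports whether the target user owned any peers at all, and deleteRegularUser only fans out updateAccountPeers when one did. The shape of that optimization, reduced to its essentials (the names are the patch's; the persistence and audit steps are elided):

hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
if err != nil {
	return err
}
// ... persist the account, store the deletion event ...
if hadPeers {
	// An account-wide peer update is only worth its cost when the
	// network map could actually have changed.
	am.updateAccountPeers(ctx, account)
}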
@@ -543,6 +559,9 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin // CreatePAT creates a new PAT for the given user func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if tokenName == "" { return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } @@ -551,28 +570,35 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) - if err != nil { - return nil, err + targetUser, ok := account.Users[targetUserID] + if !ok { + return nil, status.Errorf(status.NotFound, "user not found") } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) || - executingUser.AccountID != accountID { + executingUser, ok := account.Users[initiatorUserID] + if !ok { + return nil, status.Errorf(status.NotFound, "user not found") + } + + if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") } - pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id, executingUser.Id) + pat, err := CreateNewPAT(tokenName, expiresIn, targetUserID, executingUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } - if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { - return nil, fmt.Errorf("failed to save PAT: %w", err) + targetUser.PATs[pat.ID] = &pat.PersonalAccessToken + + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to save account: %v", err) } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} @@ -583,7 +609,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { return err } @@ -593,24 +619,19 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string return err } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) || - executingUser.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") - } - - pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID) + pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID) if err != nil { return err } - if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, tokenID, targetUserID); err != nil { - return 
fmt.Errorf("failed to delete PAT: %w", err) + if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { + return status.NewUnauthorizedToViewPATsError() } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) - return nil + return am.Store.DeletePAT(ctx, LockingStrengthUpdate, targetUserID, tokenID) } // GetPAT returns a specific PAT from a user @@ -620,8 +641,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, err } - if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { + return nil, status.NewUnauthorizedToViewPATsError() } return am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID) @@ -634,21 +659,15 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) - if err != nil { - return nil, err + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { + return nil, status.NewUnauthorizedToViewPATsError() } - pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) - for _, pat := range targetUser.PATsG { - pats = append(pats, pat.Copy()) - } - - return pats, nil + return am.Store.GetUserPATs(ctx, LockingStrengthShare, targetUserID) } // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. 
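GetPAT, GetAllPATs and DeletePAT above now share one access rule: a user may always act on their own tokens, while acting on someone else's is denied to regular users. Distilled into a single predicate (canAccessPATs is hypothetical, written only to make the repeated check explicit):

// canAccessPATs captures the check repeated in GetPAT/GetAllPATs/DeletePAT:
// self-access is always allowed; cross-user access requires more than a
// regular user, i.e. admin power or service-user status.
func canAccessPATs(initiator *User, initiatorID, targetID string) bool {
	if initiatorID == targetID {
		return true
	}
	return !initiator.IsRegularUser()
}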
@@ -703,6 +722,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, updatedUsers := make([]*UserInfo, 0, len(updates)) var ( expiredPeers []*nbpeer.Peer + userIDs []string eventsToStore []func() ) @@ -711,6 +731,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } + userIDs = append(userIDs, update.Id) + oldUser := account.Users[update.Id] if oldUser == nil { if !addIfNotExists { @@ -774,7 +796,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, err } - if account.Settings.GroupsPropagationEnabled { + if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { am.updateAccountPeers(ctx, account) } @@ -1050,7 +1072,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if peer.Status.LoginExpired { continue } - peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) @@ -1070,7 +1091,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return fmt.Errorf(errGetAccountFmt, err) + return fmt.Errorf("error getting account: %w", err) } am.updateAccountPeers(ctx, account) } @@ -1131,7 +1152,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return status.Errorf(status.PermissionDenied, "only users with admin power can delete users") } - var allErrors error + var ( + allErrors error + updateAccountPeers bool + ) deletedUsersMeta := make(map[string]map[string]any) for _, targetUserID := range targetUserIDs { @@ -1157,12 +1181,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err)) continue } + if hadPeers { + updateAccountPeers = true + } + delete(account.Users, targetUserID) deletedUsersMeta[targetUserID] = meta } @@ -1172,7 +1200,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return fmt.Errorf("failed to delete users: %w", err) } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } for targetUserID, meta := range deletedUsersMeta { am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) @@ -1181,11 +1211,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) { +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) - return nil, err + return nil, false, err } if !isNil(am.idpManager) { @@ -1196,16 +1226,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun err = am.deleteUserFromIDP(ctx, targetUserID, 
account.Id) if err != nil { log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) - return nil, err + return nil, false, err } } else { log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) } } - err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) + hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) if err != nil { - return nil, err + return nil, false, err } u, err := account.FindUser(targetUserID) @@ -1218,7 +1248,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun tuCreatedAt = u.CreatedAt } - return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil + return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. @@ -1297,3 +1327,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa } return nil, false } + +// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. +func areUsersLinkedToPeers(account *Account, userIDs []string) bool { + for _, peer := range account.Peers { + if slices.Contains(userIDs, peer.UserID) { + return true + } + } + return false +} diff --git a/management/server/user_test.go b/management/server/user_test.go index 1a5704551..d4f560a54 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,9 +10,12 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" @@ -1264,3 +1267,165 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { }) } } + +func TestUserAccountPeersUpdate(t *testing.T) { + // account groups propagation is enabled + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + require.NoError(t, err) + + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Creating a new regular user should not update account peers and not send peer update + t.Run("creating new regular user with no groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleUser, + Issued: UserIssuedAPI, + }, true) + 
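areUsersLinkedToPeers, defined above, runs slices.Contains once per peer, so its cost is O(len(Peers) x len(userIDs)). For the handful of users a single SaveOrAddUsers call touches that is fine; if the ID list ever grew large, a set-based variant would keep the scan linear (a sketch under that assumption, not part of the patch):

// areUsersLinkedToPeersSet is an O(len(Peers)+len(userIDs)) equivalent of
// areUsersLinkedToPeers: build a set of user IDs once, then scan peers.
func areUsersLinkedToPeersSet(account *Account, userIDs []string) bool {
	ids := make(map[string]struct{}, len(userIDs))
	for _, id := range userIDs {
		ids[id] = struct{}{}
	}
	for _, peer := range account.Peers {
		if _, ok := ids[peer.UserID]; ok {
			return true
		}
	}
	return false
}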
require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // updating user with no linked peers should not update account peers and not send peer update + t.Run("updating user with no linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleUser, + Issued: UserIssuedAPI, + }, false) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // deleting user with no linked peers should not update account peers and not send peer update + t.Run("deleting user with no linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser1") + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // create a user and add new peer with the user + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + }, true) + require.NoError(t, err) + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + // updating user with linked peers should update account peers and send peer update + t.Run("updating user with linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + }, false) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + }) + + // deleting user with linked peers should update account peers and send peer update + t.Run("deleting user with linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, peer4UpdMsg) + close(done) + }() + + err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser2") + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/relay/client/client.go b/relay/client/client.go index 90bc3ac41..a82a75453 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -142,6 +142,7 @@ type Client struct { muInstanceURL sync.Mutex onDisconnectListener func() + onConnectedListener func() listenerMutex sync.Mutex } @@ -184,12 +185,15 @@ func (c *Client) Connect() error { return err } + c.log = 
c.log.WithField("relay", c.instanceURL.String()) + c.log.Infof("relay connection established") + c.serviceIsRunning = true c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) + go c.notifyConnected() - c.log.Infof("relay connection established") return nil } @@ -236,6 +240,12 @@ func (c *Client) SetOnDisconnectListener(fn func()) { c.onDisconnectListener = fn } +func (c *Client) SetOnConnectedListener(fn func()) { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + c.onConnectedListener = fn +} + // HasConns returns true if there are connections. func (c *Client) HasConns() bool { c.mu.Lock() @@ -243,6 +253,12 @@ func (c *Client) HasConns() bool { return len(c.conns) > 0 } +func (c *Client) Ready() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.serviceIsRunning +} + // Close closes the connection to the relay server and all connections to other peers. func (c *Client) Close() error { return c.close(true) @@ -361,9 +377,9 @@ func (c *Client) readLoop(relayConn net.Conn) { c.instanceURL = nil c.muInstanceURL.Unlock() - c.notifyDisconnected() c.wgReadLoop.Done() _ = c.close(false) + c.notifyDisconnected() } func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) { @@ -542,6 +558,16 @@ func (c *Client) notifyDisconnected() { go c.onDisconnectListener() } +func (c *Client) notifyConnected() { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + + if c.onConnectedListener == nil { + return + } + go c.onConnectedListener() +} + func (c *Client) writeCloseMsg() { msg := messages.MarshalCloseMsg() _, err := c.relayConn.Write(msg) diff --git a/relay/client/guard.go b/relay/client/guard.go index f826cf1b6..d6b6b0da5 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -29,6 +29,10 @@ func NewGuard(context context.Context, relayClient *Client) *Guard { // OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent func (g *Guard) OnDisconnected() { + if g.quickReconnect() { + return + } + ticker := time.NewTicker(reconnectingTimeout) defer ticker.Stop() @@ -46,3 +50,19 @@ func (g *Guard) OnDisconnected() { } } } + +func (g *Guard) quickReconnect() bool { + ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) + defer cancel() + <-ctx.Done() + + if g.ctx.Err() != nil { + return false + } + + if err := g.relayClient.Connect(); err != nil { + log.Errorf("failed to reconnect to relay server: %s", err) + return false + } + return true +} diff --git a/relay/client/manager.go b/relay/client/manager.go index 4554c7c0f..3981415fc 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -65,6 +65,7 @@ type Manager struct { relayClientsMutex sync.RWMutex onDisconnectedListeners map[string]*list.List + onReconnectedListenerFn func() listenerLock sync.Mutex } @@ -101,6 +102,7 @@ func (m *Manager) Serve() error { m.relayClient = client m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnConnectedListener(m.onServerConnected) m.relayClient.SetOnDisconnectListener(func() { m.onServerDisconnected(client.connectionURL) }) @@ -138,6 +140,18 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { return netConn, err } +// Ready returns true if the home Relay client is connected to the relay server. 
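quickReconnect above uses context.WithTimeout plus <-ctx.Done() as a cancellable sleep: it waits roughly 1.5 seconds, and if the guard's parent context was cancelled in the meantime it gives up instead of dialling. The pattern in isolation (contextSleep is illustrative, not part of the patch):

// contextSleep blocks for d or until ctx is cancelled, whichever comes
// first. It returns ctx.Err(), so nil means the full duration elapsed and
// non-nil means the caller should stop instead of retrying.
func contextSleep(ctx context.Context, d time.Duration) error {
	timer, cancel := context.WithTimeout(ctx, d)
	defer cancel()
	<-timer.Done()
	return ctx.Err()
}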
+func (m *Manager) Ready() bool { + if m.relayClient == nil { + return false + } + return m.relayClient.Ready() +} + +func (m *Manager) SetOnReconnectedListener(f func()) { + m.onReconnectedListenerFn = f +} + // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { @@ -240,6 +254,13 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return conn, nil } +func (m *Manager) onServerConnected() { + if m.onReconnectedListenerFn == nil { + return + } + go m.onReconnectedListenerFn() +} + func (m *Manager) onServerDisconnected(serverAddress string) { if serverAddress == m.relayClient.connectionURL { go m.reconnectGuard.OnDisconnected() diff --git a/signal/client/client.go b/signal/client/client.go index ced3fb7d0..eff1ccb87 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -35,6 +35,7 @@ type Client interface { WaitStreamConnected() SendToStream(msg *proto.EncryptedMessage) error Send(msg *proto.Message) error + SetOnReconnectedListener(func()) } // UnMarshalCredential parses the credentials from the message and returns a Credential instance diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7a3b502ff..2ff84e460 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -43,6 +43,8 @@ type GrpcClient struct { connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex + + onReconnectedListenerFn func() } func (c *GrpcClient) StreamConnected() bool { @@ -181,12 +183,17 @@ func (c *GrpcClient) notifyStreamDisconnected() { func (c *GrpcClient) notifyStreamConnected() { c.mux.Lock() defer c.mux.Unlock() + c.status = StreamConnected if c.connectedCh != nil { // there are goroutines waiting on this channel -> release them close(c.connectedCh) c.connectedCh = nil } + + if c.onReconnectedListenerFn != nil { + c.onReconnectedListenerFn() + } } func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { @@ -271,6 +278,13 @@ func (c *GrpcClient) WaitStreamConnected() { } } +func (c *GrpcClient) SetOnReconnectedListener(fn func()) { + c.mux.Lock() + defer c.mux.Unlock() + + c.onReconnectedListenerFn = fn +} + // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // GrpcClient.connWg can be used to wait diff --git a/signal/client/mock.go b/signal/client/mock.go index 70ecea9ed..32236c82c 100644 --- a/signal/client/mock.go +++ b/signal/client/mock.go @@ -7,14 +7,20 @@ import ( ) type MockClient struct { - CloseFunc func() error - GetStatusFunc func() Status - StreamConnectedFunc func() bool - ReadyFunc func() bool - WaitStreamConnectedFunc func() - ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error - SendToStreamFunc func(msg *proto.EncryptedMessage) error - SendFunc func(msg *proto.Message) error + CloseFunc func() error + GetStatusFunc func() Status + StreamConnectedFunc func() bool + ReadyFunc func() bool + WaitStreamConnectedFunc func() + ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error + SendToStreamFunc func(msg *proto.EncryptedMessage) error + SendFunc func(msg *proto.Message) error + SetOnReconnectedListenerFunc func(f func()) +} + +// SetOnReconnectedListener sets 
the function to be called when the client reconnects. +func (sm *MockClient) SetOnReconnectedListener(_ func()) { + // Do nothing } func (sm *MockClient) IsHealthy() bool { diff --git a/util/file_suite_test.go b/util/file_suite_test.go deleted file mode 100644 index 3de7db49b..000000000 --- a/util/file_suite_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package util_test - -import ( - "crypto/md5" - "encoding/hex" - "io" - "os" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" -) - -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int - } - - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) - - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { - - m := make(map[string]string) - m["key1"] = "value1" - m["key2"] = "value2" - - arr := []string{"value1", "value2"} - - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, - } - - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) - - }) - }) - }) - - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) - - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) - - hashSrc := md5.New() - hashDst := md5.New() - - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, - } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) - }) - }) -}) diff --git a/util/file_test.go b/util/file_test.go index 1330e738e..566d8eda6 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,12 +1,142 @@ package util import ( + "crypto/md5" + "encoding/hex" + "io" "os" "reflect" "strings" "testing" + + "github.com/stretchr/testify/require" 
) +type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int +} + +func TestConfigJSON(t *testing.T) { + tests := []struct { + name string + config *TestConfig + expectedError bool + }{ + { + name: "Valid JSON config", + config: &TestConfig{ + SomeMap: map[string]string{"key1": "value1", "key2": "value2"}, + SomeArray: []string{"value1", "value2"}, + SomeField: 99, + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + + err := WriteJson(tmpDir+"/testconfig.json", tt.config) + require.NoError(t, err) + + read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + require.NoError(t, err) + require.NotNil(t, read) + require.Equal(t, tt.config.SomeMap["key1"], read.(*TestConfig).SomeMap["key1"]) + require.Equal(t, tt.config.SomeMap["key2"], read.(*TestConfig).SomeMap["key2"]) + require.ElementsMatch(t, tt.config.SomeArray, read.(*TestConfig).SomeArray) + require.Equal(t, tt.config.SomeField, read.(*TestConfig).SomeField) + }) + } +} + +func TestCopyFileContents(t *testing.T) { + tests := []struct { + name string + srcContent []string + expectedError bool + }{ + { + name: "Copy file contents successfully", + srcContent: []string{"1", "2", "3"}, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := WriteJson(src, tt.srcContent) + require.NoError(t, err) + + err = CopyFileContents(src, dst) + require.NoError(t, err) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + require.NoError(t, err) + defer func() { + _ = srcFile.Close() + }() + + dstFile, err := os.Open(dst) + require.NoError(t, err) + defer func() { + _ = dstFile.Close() + }() + + _, err = io.Copy(hashSrc, srcFile) + require.NoError(t, err) + + _, err = io.Copy(hashDst, dstFile) + require.NoError(t, err) + + require.Equal(t, hex.EncodeToString(hashSrc.Sum(nil)[:16]), hex.EncodeToString(hashDst.Sum(nil)[:16])) + }) + } +} + +func TestHandleConfigFileWithoutFullPath(t *testing.T) { + tests := []struct { + name string + config *TestConfig + expectedError bool + }{ + { + name: "Handle config file without full path", + config: &TestConfig{ + SomeField: 123, + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgFile := "test_cfg.json" + defer func() { + _ = os.Remove(cfgFile) + }() + + err := WriteJson(cfgFile, tt.config) + require.NoError(t, err) + + read, err := ReadJson(cfgFile, &TestConfig{}) + require.NoError(t, err) + require.NotNil(t, read) + }) + } +} + func TestReadJsonWithEnvSub(t *testing.T) { type Config struct { CertFile string `json:"CertFile"`