diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 137e00d31..3ceb15eb6 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -7,10 +7,8 @@ import ( "runtime" "time" - "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" - nberrors "github.com/netbirdio/netbird/client/errors" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" @@ -36,6 +34,7 @@ const ( reasonRouteUpdate reasonPeerUpdate reasonShutdown + reasonHA ) type routerPeerStatus struct { @@ -141,11 +140,11 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { // // It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { - chosen := route.ID("") + var chosen route.ID chosenScore := float64(0) currScore := float64(0) - currID := route.ID("") + var currID route.ID if c.currentChosen != nil { currID = c.currentChosen.ID } @@ -254,45 +253,47 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { } } -func (c *clientNetwork) removeRouteFromWireGuardPeer() error { - if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil { +// addAllowedIPs adds the allowed IPs for the current chosen route to the handler. +func (c *clientNetwork) addAllowedIPs(route *route.Route) error { + if err := c.handler.AddAllowedIPs(route.Peer); err != nil { + return fmt.Errorf("add allowed IPs for peer %s: %w", route.Peer, err) + } + + if err := c.statusRecorder.AddPeerStateRoute(route.Peer, c.handler.String(), route.GetResourceID()); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } + + c.connectEvent(route) + return nil +} + +func (c *clientNetwork) removeAllowedIPs(route *route.Route, rsn reason) error { + if err := c.statusRecorder.RemovePeerStateRoute(route.Peer, c.handler.String()); err != nil { log.Warnf("Failed to update peer state: %v", err) } if err := c.handler.RemoveAllowedIPs(); err != nil { return fmt.Errorf("remove allowed IPs: %w", err) } + + c.disconnectEvent(route, rsn) + return nil } -func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error { - if c.currentChosen == nil { - return nil - } - - var merr *multierror.Error - - if err := c.removeRouteFromWireGuardPeer(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) - } - if err := c.handler.RemoveRoute(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err)) - } - - c.disconnectEvent(rsn) - - return nberrors.FormatErrorOrNil(merr) -} - -func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error { +func (c *clientNetwork) recalculateRoutes(rsn reason) error { routerPeerStatuses := c.getRouterPeerStatuses() newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses) - // If no route is chosen, remove the route from the peer and system + // If no route is chosen, remove the route from the peer if newChosenID == "" { - if err := c.removeRouteFromPeerAndSystem(rsn); err != nil { - return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err) + if c.currentChosen == nil { + return nil + } + + if err := c.removeAllowedIPs(c.currentChosen, rsn); err != nil { + return fmt.Errorf("remove obsolete: %w", err) } c.currentChosen = nil @@ -306,38 +307,24 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error return nil } - var isNew bool - if c.currentChosen == nil { - // If they were not previously assigned to another peer, add routes to the system first - if err := c.handler.AddRoute(c.ctx); err != nil { - return fmt.Errorf("add route: %w", err) - } - isNew = true - } else { - // Otherwise, remove the allowed IPs from the previous peer first - if err := c.removeRouteFromWireGuardPeer(); err != nil { - return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) + // If the chosen route was assigned to a different peer, remove the allowed IPs first + if isNew := c.currentChosen == nil; !isNew { + if err := c.removeAllowedIPs(c.currentChosen, reasonHA); err != nil { + return fmt.Errorf("remove old: %w", err) } } - c.currentChosen = c.routes[newChosenID] - - if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil { - return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) + newChosenRoute := c.routes[newChosenID] + if err := c.addAllowedIPs(newChosenRoute); err != nil { + return fmt.Errorf("add new: %w", err) } - if isNew { - c.connectEvent() - } + c.currentChosen = newChosenRoute - err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID()) - if err != nil { - return fmt.Errorf("add peer state route: %w", err) - } return nil } -func (c *clientNetwork) connectEvent() { +func (c *clientNetwork) connectEvent(route *route.Route) { var defaultRoute bool for _, r := range c.routes { if r.Network.Bits() == 0 { @@ -353,9 +340,9 @@ func (c *clientNetwork) connectEvent() { meta := map[string]string{ "network": c.handler.String(), } - if c.currentChosen != nil { - meta["id"] = string(c.currentChosen.NetID) - meta["peer"] = c.currentChosen.Peer + if route != nil { + meta["id"] = string(route.NetID) + meta["peer"] = route.Peer } c.statusRecorder.PublishEvent( proto.SystemEvent_INFO, @@ -366,7 +353,7 @@ func (c *clientNetwork) connectEvent() { ) } -func (c *clientNetwork) disconnectEvent(rsn reason) { +func (c *clientNetwork) disconnectEvent(route *route.Route, rsn reason) { var defaultRoute bool for _, r := range c.routes { if r.Network.Bits() == 0 { @@ -384,9 +371,9 @@ func (c *clientNetwork) disconnectEvent(rsn reason) { var userMessage string meta := make(map[string]string) - if c.currentChosen != nil { - meta["id"] = string(c.currentChosen.NetID) - meta["peer"] = c.currentChosen.Peer + if route != nil { + meta["id"] = string(route.NetID) + meta["peer"] = route.Peer } meta["network"] = c.handler.String() switch rsn { @@ -401,6 +388,10 @@ func (c *clientNetwork) disconnectEvent(rsn reason) { severity = proto.SystemEvent_WARNING message = "Default route disconnected due to peer unreachability" userMessage = "Exit node connection lost. Your internet access might be affected." + case reasonHA: + severity = proto.SystemEvent_INFO + message = "Default route disconnected due to high availability change" + userMessage = "Exit node disconnected due to high availability change." default: severity = proto.SystemEvent_ERROR message = "Default route disconnected for unknown reasons" @@ -458,12 +449,12 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { select { case <-c.ctx.Done(): log.Debugf("Stopping watcher for network [%v]", c.handler) - if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil { + if err := c.removeAllowedIPs(c.currentChosen, reasonShutdown); err != nil { log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err) } return case <-c.peerStateUpdate: - err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate) + err := c.recalculateRoutes(reasonPeerUpdate) if err != nil { log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) } @@ -482,7 +473,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { if isTrueRouteUpdate { log.Debug("Client network update contains different routes, recalculating routes") - err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate) + err := c.recalculateRoutes(reasonRouteUpdate) if err != nil { log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) } diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index 98c34dbee..ea8991c12 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -30,13 +30,17 @@ func (r *Route) String() string { } func (r *Route) AddRoute(context.Context) error { - _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) - return err + if _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}); err != nil { + return err + } + return nil } func (r *Route) RemoveRoute() error { - _, err := r.routeRefCounter.Decrement(r.route.Network) - return err + if _, err := r.routeRefCounter.Decrement(r.route.Network); err != nil { + return err + } + return nil } func (r *Route) AddAllowedIPs(peerKey string) error { @@ -52,6 +56,8 @@ func (r *Route) AddAllowedIPs(peerKey string) error { } func (r *Route) RemoveAllowedIPs() error { - _, err := r.allowedIPsRefcounter.Decrement(r.route.Network) - return err + if _, err := r.allowedIPsRefcounter.Decrement(r.route.Network); err != nil { + return err + } + return nil }