diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index fdc2c3063..0d3f033fb 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { - policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) - if r.gpo { - policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) - } singleDomain := []string{domain} - if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) + if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) + } + + if r.gpo { + if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure gpo DNS policy: %w", err) + } } log.Debugf("added NRPT entry for domain: %s", domain) @@ -401,6 +403,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) } @@ -412,6 +415,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) } diff --git a/client/internal/engine.go b/client/internal/engine.go index d4c465efb..828bc6e94 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -198,6 +198,10 @@ type Engine struct { latestSyncResponse *mgmProto.SyncResponse connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager + + // WireGuard interface monitor + wgIfaceMonitor *WGIfaceMonitor + wgIfaceMonitorWg sync.WaitGroup } // Peer is an instance of the Connection Peer @@ -341,6 +345,9 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } + // Stop WireGuard interface monitor and wait for it to exit + e.wgIfaceMonitorWg.Wait() + return nil } @@ -479,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // starting network monitor at the very last to avoid disruptions e.startNetworkMonitor() + + // monitor WireGuard interface lifecycle and restart engine on changes + e.wgIfaceMonitor = NewWGIfaceMonitor() + e.wgIfaceMonitorWg.Add(1) + + go func() { + defer e.wgIfaceMonitorWg.Done() + + if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { + log.Infof("WireGuard interface monitor: %s, restarting engine", err) + e.restartEngine() + } else if err != nil { + log.Warnf("WireGuard interface monitor: %s", err) + } + }() + return nil } diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 95645329e..7bce6af80 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -908,7 +908,8 @@ func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { if iface, err := net.InterfaceByName(vpnIntf); err == nil { skipInterfaceIndex = iface.Index } else { - return nil, fmt.Errorf("get VPN interface %s: %w", vpnIntf, err) + // not critical, if we cannot get ahold of the interface then we won't need to skip it + log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err) } } diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go new file mode 100644 index 000000000..78d70c15b --- /dev/null +++ b/client/internal/wg_iface_monitor.go @@ -0,0 +1,98 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net" + "runtime" + "time" + + log "github.com/sirupsen/logrus" +) + +// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine +// if the interface is deleted externally while the engine is running. +type WGIfaceMonitor struct { + done chan struct{} +} + +// NewWGIfaceMonitor creates a new WGIfaceMonitor instance. +func NewWGIfaceMonitor() *WGIfaceMonitor { + return &WGIfaceMonitor{ + done: make(chan struct{}), + } +} + +// Start begins monitoring the WireGuard interface. +// It relies on the provided context cancellation to stop. +func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { + defer close(m.done) + + // Skip on mobile platforms as they handle interface lifecycle differently + if runtime.GOOS == "android" || runtime.GOOS == "ios" { + log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS) + return false, errors.New("not supported on mobile platforms") + } + + if ifaceName == "" { + log.Debugf("Interface monitor: empty interface name, skipping monitor") + return false, errors.New("empty interface name") + } + + // Get initial interface index to track the specific interface instance + expectedIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName) + return false, fmt.Errorf("interface %s not found: %w", ifaceName, err) + } + + log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } + +} + +// getInterfaceIndex returns the index of a network interface by name. +// Returns an error if the interface is not found. +func getInterfaceIndex(name string) (int, error) { + if name == "" { + return 0, fmt.Errorf("empty interface name") + } + ifi, err := net.InterfaceByName(name) + if err != nil { + // Check if it's specifically a "not found" error + if errors.Is(err, &net.OpError{}) { + // On some systems, this might be a "not found" error + return 0, fmt.Errorf("interface not found: %w", err) + } + return 0, fmt.Errorf("failed to lookup interface: %w", err) + } + if ifi == nil { + return 0, fmt.Errorf("interface not found") + } + return ifi.Index, nil +} diff --git a/management/server/dns.go b/management/server/dns.go index f6f0201d3..6b73dbd0e 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -20,29 +20,9 @@ import ( // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { - CustomZones sync.Map NameServerGroups sync.Map } -// GetCustomZone retrieves a cached custom zone -func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) { - if c == nil { - return nil, false - } - if value, ok := c.CustomZones.Load(key); ok { - return value.(*proto.CustomZone), true - } - return nil, false -} - -// SetCustomZone stores a custom zone in the cache -func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) { - if c == nil { - return - } - c.CustomZones.Store(key, value) -} - // GetNameServerGroup retrieves a cached name server group func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { if c == nil { @@ -212,14 +192,8 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC } for _, zone := range update.CustomZones { - cacheKey := zone.Domain - if cachedZone, exists := cache.GetCustomZone(cacheKey); exists { - protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone) - } else { - protoZone := convertToProtoCustomZone(zone) - cache.SetCustomZone(cacheKey, protoZone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) } for _, nsGroup := range update.NameServerGroups { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index d58689544..55a1bbe66 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -474,15 +474,6 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Results should be different for different inputs") } - // Verify that the cache contains elements from both configs - if _, exists := cache.GetCustomZone("example.com"); !exists { - t.Errorf("Cache should contain custom zone for example.com") - } - - if _, exists := cache.GetCustomZone("example.org"); !exists { - t.Errorf("Cache should contain custom zone for example.org") - } - if _, exists := cache.GetNameServerGroup("group1"); !exists { t.Errorf("Cache should contain name server group 'group1'") } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 27d54e6c2..60a00207e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -258,6 +258,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } diff --git a/management/server/types/account.go b/management/server/types/account.go index ca075b9f6..a69d3bb08 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -300,7 +300,6 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone - if peersCustomZone.Domain != "" { records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) zones = append(zones, nbdns.CustomZone{ diff --git a/management/server/user.go b/management/server/user.go index 3c7c3f433..d40d33c6a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -965,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service + log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.BufferUpdateAccountPeers(ctx, accountID) }