mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Merge branch 'main' into refactor/reducate-signaling
This commit is contained in:
@@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
if r.gpo {
|
||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
}
|
||||
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
@@ -401,6 +403,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||
}
|
||||
@@ -412,6 +415,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
@@ -198,6 +198,10 @@ type Engine struct {
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -341,6 +345,9 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -479,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -908,7 +908,8 @@ func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
|
||||
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
|
||||
skipInterfaceIndex = iface.Index
|
||||
} else {
|
||||
return nil, fmt.Errorf("get VPN interface %s: %w", vpnIntf, err)
|
||||
// not critical, if we cannot get ahold of the interface then we won't need to skip it
|
||||
log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
98
client/internal/wg_iface_monitor.go
Normal file
98
client/internal/wg_iface_monitor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||
// if the interface is deleted externally while the engine is running.
|
||||
type WGIfaceMonitor struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewWGIfaceMonitor creates a new WGIfaceMonitor instance.
|
||||
func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||
return &WGIfaceMonitor{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins monitoring the WireGuard interface.
|
||||
// It relies on the provided context cancellation to stop.
|
||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||
defer close(m.done)
|
||||
|
||||
// Skip on mobile platforms as they handle interface lifecycle differently
|
||||
if runtime.GOOS == "android" || runtime.GOOS == "ios" {
|
||||
log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS)
|
||||
return false, errors.New("not supported on mobile platforms")
|
||||
}
|
||||
|
||||
if ifaceName == "" {
|
||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||
return false, errors.New("empty interface name")
|
||||
}
|
||||
|
||||
// Get initial interface index to track the specific interface instance
|
||||
expectedIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName)
|
||||
return false, fmt.Errorf("interface %s not found: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// getInterfaceIndex returns the index of a network interface by name.
|
||||
// Returns an error if the interface is not found.
|
||||
func getInterfaceIndex(name string) (int, error) {
|
||||
if name == "" {
|
||||
return 0, fmt.Errorf("empty interface name")
|
||||
}
|
||||
ifi, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
// Check if it's specifically a "not found" error
|
||||
if errors.Is(err, &net.OpError{}) {
|
||||
// On some systems, this might be a "not found" error
|
||||
return 0, fmt.Errorf("interface not found: %w", err)
|
||||
}
|
||||
return 0, fmt.Errorf("failed to lookup interface: %w", err)
|
||||
}
|
||||
if ifi == nil {
|
||||
return 0, fmt.Errorf("interface not found")
|
||||
}
|
||||
return ifi.Index, nil
|
||||
}
|
||||
@@ -20,29 +20,9 @@ import (
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
CustomZones sync.Map
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetCustomZone retrieves a cached custom zone
|
||||
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.CustomZones.Load(key); ok {
|
||||
return value.(*proto.CustomZone), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetCustomZone stores a custom zone in the cache
|
||||
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.CustomZones.Store(key, value)
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
@@ -212,14 +192,8 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
cacheKey := zone.Domain
|
||||
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
|
||||
} else {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
cache.SetCustomZone(cacheKey, protoZone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
|
||||
@@ -474,15 +474,6 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
// Verify that the cache contains elements from both configs
|
||||
if _, exists := cache.GetCustomZone("example.com"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.com")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetCustomZone("example.org"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.org")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
@@ -258,6 +258,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -300,7 +300,6 @@ func (a *Account) GetPeerNetworkMap(
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
|
||||
@@ -965,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
|
||||
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user