mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
4 Commits
feature/an
...
snyk-fix-7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22df3afa78 | ||
|
|
17bab881f7 | ||
|
|
25ed58328a | ||
|
|
644ed4b934 |
@@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
if r.gpo {
|
||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
}
|
||||
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
@@ -401,6 +403,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||
}
|
||||
@@ -412,6 +415,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
@@ -198,6 +198,10 @@ type Engine struct {
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -341,6 +345,9 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -479,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
98
client/internal/wg_iface_monitor.go
Normal file
98
client/internal/wg_iface_monitor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||
// if the interface is deleted externally while the engine is running.
|
||||
type WGIfaceMonitor struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewWGIfaceMonitor creates a new WGIfaceMonitor instance.
|
||||
func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||
return &WGIfaceMonitor{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins monitoring the WireGuard interface.
|
||||
// It relies on the provided context cancellation to stop.
|
||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||
defer close(m.done)
|
||||
|
||||
// Skip on mobile platforms as they handle interface lifecycle differently
|
||||
if runtime.GOOS == "android" || runtime.GOOS == "ios" {
|
||||
log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS)
|
||||
return false, errors.New("not supported on mobile platforms")
|
||||
}
|
||||
|
||||
if ifaceName == "" {
|
||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||
return false, errors.New("empty interface name")
|
||||
}
|
||||
|
||||
// Get initial interface index to track the specific interface instance
|
||||
expectedIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName)
|
||||
return false, fmt.Errorf("interface %s not found: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// getInterfaceIndex returns the index of a network interface by name.
|
||||
// Returns an error if the interface is not found.
|
||||
func getInterfaceIndex(name string) (int, error) {
|
||||
if name == "" {
|
||||
return 0, fmt.Errorf("empty interface name")
|
||||
}
|
||||
ifi, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
// Check if it's specifically a "not found" error
|
||||
if errors.Is(err, &net.OpError{}) {
|
||||
// On some systems, this might be a "not found" error
|
||||
return 0, fmt.Errorf("interface not found: %w", err)
|
||||
}
|
||||
return 0, fmt.Errorf("failed to lookup interface: %w", err)
|
||||
}
|
||||
if ifi == nil {
|
||||
return 0, fmt.Errorf("interface not found")
|
||||
}
|
||||
return ifi.Index, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM ubuntu:24.04
|
||||
FROM ubuntu:24.10
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
||||
CMD ["--log-file", "console"]
|
||||
|
||||
@@ -20,29 +20,9 @@ import (
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
CustomZones sync.Map
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetCustomZone retrieves a cached custom zone
|
||||
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.CustomZones.Load(key); ok {
|
||||
return value.(*proto.CustomZone), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetCustomZone stores a custom zone in the cache
|
||||
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.CustomZones.Store(key, value)
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
@@ -212,14 +192,8 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
cacheKey := zone.Domain
|
||||
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
|
||||
} else {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
cache.SetCustomZone(cacheKey, protoZone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
|
||||
@@ -474,15 +474,6 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
// Verify that the cache contains elements from both configs
|
||||
if _, exists := cache.GetCustomZone("example.com"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.com")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetCustomZone("example.org"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.org")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
@@ -300,7 +300,6 @@ func (a *Account) GetPeerNetworkMap(
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
|
||||
Reference in New Issue
Block a user