Merge branch 'client-ipv6-dns' into client-ipv6-ssh-netflow

This commit is contained in:
Viktor Liu
2026-03-26 13:00:10 +01:00
8 changed files with 38 additions and 11 deletions

View File

@@ -85,6 +85,11 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// Mock implementation - no-op
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op

View File

@@ -57,6 +57,7 @@ type Server interface {
ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
SetRouteChecker(func(netip.Addr) bool)
}
type nsGroupsByDomain struct {
@@ -104,6 +105,7 @@ type DefaultServer struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
routeMatch func(netip.Addr) bool
probeMu sync.Mutex
probeCancel context.CancelFunc
@@ -229,6 +231,14 @@ func newDefaultServer(
return defaultServer
}
// SetRouteChecker sets the function used by upstream resolvers to determine
// whether an IP is routed through the tunnel.
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
s.mux.Lock()
defer s.mux.Unlock()
s.routeMatch = f
}
// RegisterHandler registers a handler for the given domains with the given priority.
// Any previously registered handler for the same domain and priority will be replaced.
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -743,6 +753,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
return
}
handler.routeMatch = s.routeMatch
for _, ns := range originalNameservers {
if ns == config.ServerIP {
@@ -852,6 +863,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
if err != nil {
return nil, fmt.Errorf("create upstream resolver: %v", err)
}
handler.routeMatch = s.routeMatch
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
@@ -1036,6 +1048,7 @@ func (s *DefaultServer) addHostRootZone() {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
}
handler.routeMatch = s.routeMatch
handler.upstreamServers = maps.Keys(hostDNSServers)
handler.deactivate = func(error) {}

View File

@@ -70,6 +70,7 @@ type upstreamResolverBase struct {
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {

View File

@@ -69,12 +69,9 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} else {
upstreamIP = upstreamIP.Unmap()
}
// TODO: IsPrivate is a rough heuristic. It misses public IPs routed through
// the tunnel (e.g. 9.9.9.9 via network route) and incorrectly matches local
// LAN private IPs. Replace with a check against the active route table or
// the set of routed prefixes from the network map.
needsPrivate := u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() ||
(u.lNetV6.IsValid() && u.lNetV6.Contains(upstreamIP))
needsPrivate := u.lNet.Contains(upstreamIP) ||
u.lNetV6.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
if needsPrivate {
var bindIP netip.Addr
switch {
@@ -85,7 +82,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
if bindIP.IsValid() {
log.Debugf("using private client to query upstream: %s", upstream)
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(bindIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("create private client: %s", err)

View File

@@ -503,6 +503,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()