diff --git a/client/android/client.go b/client/android/client.go index 3b8a5bd0f..79067398f 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -203,10 +203,6 @@ func (c *Client) Networks() *NetworkArray { continue } - if routes[0].IsDynamic() { - continue - } - peer, err := c.recorder.GetPeer(routes[0].Peer) if err != nil { log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/filter.go similarity index 97% rename from client/firewall/uspfilter/uspfilter.go rename to client/firewall/uspfilter/filter.go index dcff92c61..136d3741b 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/filter.go @@ -104,6 +104,11 @@ type Manager struct { flowLogger nftypes.FlowLogger blockRule firewall.Rule + + // Internal 1:1 DNAT + dnatEnabled atomic.Bool + dnatMappings map[netip.Addr]netip.Addr + dnatMutex sync.RWMutex } // decoder for packages @@ -189,6 +194,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe flowLogger: flowLogger, netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, + dnatMappings: make(map[netip.Addr]netip.Addr), } m.routingEnabled.Store(false) @@ -519,22 +525,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } -// AddDNATRule adds a DNAT rule -func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { - if m.nativeFirewall == nil { - return nil, errNatNotSupported - } - return m.nativeFirewall.AddDNATRule(rule) -} - -// DeleteDNATRule deletes a DNAT rule -func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { - if m.nativeFirewall == nil { - return errNatNotSupported - } - return m.nativeFirewall.DeleteDNATRule(rule) -} - // UpdateSet updates the rule destinations associated with the given set // by merging the existing prefixes with the new ones, then deduplicating. func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { @@ -608,6 +598,14 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool { return false } + translated := m.translateOutboundDNAT(packetData, d) + if translated { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error("Failed to re-decode packet after DNAT: %v", err) + return false + } + } + srcIP, dstIP := m.extractIPs(d) if !srcIP.IsValid() { m.logger.Error("Unknown network layer: %v", d.decoded[0]) @@ -618,7 +616,6 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool { return true } - // for netflow we keep track even if the firewall is stateless m.trackOutbound(d, srcIP, dstIP, size) return false @@ -747,9 +744,17 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool { return false } - // For all inbound traffic, first check if it matches a tracked connection. - // This must happen before any other filtering because the packets are statefully tracked. + // Step 1: Check connection tracking FIRST (with original addresses) if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { + // Step 2: Apply reverse DNAT for established connections + translated := m.translateInboundReverse(packetData, d) + if translated { + // Re-decode after translation + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err) + return true + } + } return false } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/filter_test.go similarity index 100% rename from client/firewall/uspfilter/uspfilter_test.go rename to client/firewall/uspfilter/filter_test.go diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go new file mode 100644 index 000000000..ad1725d13 --- /dev/null +++ b/client/firewall/uspfilter/nat.go @@ -0,0 +1,309 @@ +package uspfilter + +import ( + "encoding/binary" + "fmt" + "net/netip" + + "github.com/google/gopacket/layers" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +func ipv4Checksum(header []byte) uint16 { + if len(header) < 20 { + return 0 + } + + var sum uint32 + for i := 0; i < len(header)-1; i += 2 { + sum += uint32(binary.BigEndian.Uint16(header[i : i+2])) + } + + if len(header)%2 == 1 { + sum += uint32(header[len(header)-1]) << 8 + } + + for (sum >> 16) > 0 { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + return ^uint16(sum) +} + +func icmpChecksum(data []byte) uint16 { + var sum uint32 + for i := 0; i < len(data)-1; i += 2 { + sum += uint32(binary.BigEndian.Uint16(data[i : i+2])) + } + + if len(data)%2 == 1 { + sum += uint32(data[len(data)-1]) << 8 + } + + for (sum >> 16) > 0 { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + return ^uint16(sum) +} + +func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { + if !originalAddr.IsValid() || !translatedAddr.IsValid() { + return fmt.Errorf("invalid IP addresses") + } + + if m.localipmanager.IsLocalIP(translatedAddr) { + return fmt.Errorf("cannot map to local IP: %s", translatedAddr) + } + + m.dnatMutex.Lock() + m.dnatMappings[originalAddr] = translatedAddr + if len(m.dnatMappings) == 1 { + m.dnatEnabled.Store(true) + } + m.dnatMutex.Unlock() + + return nil +} + +// RemoveInternalDNATMapping removes a 1:1 IP address mapping +func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { + m.dnatMutex.Lock() + defer m.dnatMutex.Unlock() + + if _, exists := m.dnatMappings[originalAddr]; !exists { + return fmt.Errorf("mapping not found for: %s", originalAddr) + } + + delete(m.dnatMappings, originalAddr) + if len(m.dnatMappings) == 0 { + m.dnatEnabled.Store(false) + } + + return nil +} + +// getDNATTranslation returns the translated address if a mapping exists +func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { + if !m.dnatEnabled.Load() { + return addr, false + } + + m.dnatMutex.RLock() + translated, exists := m.dnatMappings[addr] + m.dnatMutex.RUnlock() + return translated, exists +} + +// findReverseDNATMapping finds original address for return traffic +func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { + if !m.dnatEnabled.Load() { + return translatedAddr, false + } + + m.dnatMutex.RLock() + defer m.dnatMutex.RUnlock() + + for original, translated := range m.dnatMappings { + if translated == translatedAddr { + return original, true + } + } + + return translatedAddr, false +} + +// translateOutboundDNAT applies DNAT translation to outbound packets +func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + return false + } + + _, dstIP := m.extractIPs(d) + if !dstIP.IsValid() || !dstIP.Is4() { + return false + } + + translatedIP, exists := m.getDNATTranslation(dstIP) + if !exists { + return false + } + + if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { + m.logger.Error("Failed to rewrite packet destination: %v", err) + return false + } + + m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP) + return true +} + +// translateInboundReverse applies reverse DNAT to inbound return traffic +func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + return false + } + + srcIP, _ := m.extractIPs(d) + if !srcIP.IsValid() || !srcIP.Is4() { + return false + } + + originalIP, exists := m.findReverseDNATMapping(srcIP) + if !exists { + return false + } + + if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { + m.logger.Error("Failed to rewrite packet source: %v", err) + return false + } + + m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP) + return true +} + +// rewritePacketDestination replaces destination IP in the packet +func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { + if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { + return fmt.Errorf("only IPv4 supported") + } + + oldDst := make([]byte, 4) + copy(oldDst, packetData[16:20]) + newDst := newIP.AsSlice() + + copy(packetData[16:20], newDst) + + ipHeaderLen := int(d.ip4.IHL) * 4 + binary.BigEndian.PutUint16(packetData[10:12], 0) + ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) + binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) + + if len(d.decoded) > 1 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.updateTCPChecksum(packetData, ipHeaderLen, oldDst, newDst) + case layers.LayerTypeUDP: + m.updateUDPChecksum(packetData, ipHeaderLen, oldDst, newDst) + case layers.LayerTypeICMPv4: + m.updateICMPChecksum(packetData, ipHeaderLen) + } + } + + return nil +} + +// rewritePacketSource replaces the source IP address in the packet +func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { + if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { + return fmt.Errorf("only IPv4 supported") + } + + oldSrc := make([]byte, 4) + copy(oldSrc, packetData[12:16]) + newSrc := newIP.AsSlice() + + copy(packetData[12:16], newSrc) + + ipHeaderLen := int(d.ip4.IHL) * 4 + binary.BigEndian.PutUint16(packetData[10:12], 0) + ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) + binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) + + if len(d.decoded) > 1 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc, newSrc) + case layers.LayerTypeUDP: + m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc, newSrc) + case layers.LayerTypeICMPv4: + m.updateICMPChecksum(packetData, ipHeaderLen) + } + } + + return nil +} + +func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+18 { + return + } + + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) +} + +func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return + } + + checksumOffset := udpStart + 6 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + if oldChecksum == 0 { + return + } + + newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) +} + +func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { + icmpStart := ipHeaderLen + if len(packetData) < icmpStart+8 { + return + } + + icmpData := packetData[icmpStart:] + binary.BigEndian.PutUint16(icmpData[2:4], 0) + checksum := icmpChecksum(icmpData) + binary.BigEndian.PutUint16(icmpData[2:4], checksum) +} + +// incrementalUpdate performs incremental checksum update per RFC 1624 +func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { + sum := uint32(^oldChecksum) + + for i := 0; i < len(oldBytes)-1; i += 2 { + sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2])) + } + if len(oldBytes)%2 == 1 { + sum += uint32(^oldBytes[len(oldBytes)-1]) << 8 + } + + for i := 0; i < len(newBytes)-1; i += 2 { + sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2])) + } + if len(newBytes)%2 == 1 { + sum += uint32(newBytes[len(newBytes)-1]) << 8 + } + + for (sum >> 16) > 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + + return ^uint16(sum) +} + +// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { + if m.nativeFirewall == nil { + return nil, errNatNotSupported + } + return m.nativeFirewall.AddDNATRule(rule) +} + +// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { + if m.nativeFirewall == nil { + return errNatNotSupported + } + return m.nativeFirewall.DeleteDNATRule(rule) +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 253ecb2a6..771b4f229 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -488,9 +488,9 @@ func (e *Engine) createFirewall() error { } func (e *Engine) initFirewall() error { - if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { + if err := e.routeManager.SetFirewall(e.firewall); err != nil { e.close() - return fmt.Errorf("enable server router: %w", err) + return fmt.Errorf("set firewall: %w", err) } if e.config.BlockLANAccess { diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index 46bff96db..0b8e161d2 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -10,11 +10,10 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/iface" - "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/route" @@ -553,41 +552,16 @@ func (w *Watcher) Stop() { w.currentChosenStatus = nil } -func HandlerFromRoute( - rt *route.Route, - routeRefCounter *refcounter.RouteRefCounter, - allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, - dnsRouterInteval time.Duration, - statusRecorder *peer.Status, - wgInterface iface.WGIface, - dnsServer nbdns.Server, - peerStore *peerstore.Store, - useNewDNSRoute bool, -) RouteHandler { - switch handlerType(rt, useNewDNSRoute) { +func HandlerFromRoute(params common.HandlerParams) RouteHandler { + switch handlerType(params.Route, params.UseNewDNSRoute) { case handlerTypeDnsInterceptor: - return dnsinterceptor.New( - rt, - routeRefCounter, - allowedIPsRefCounter, - statusRecorder, - dnsServer, - wgInterface, - peerStore, - ) + return dnsinterceptor.New(params) case handlerTypeDynamic: - dns := nbdns.NewServiceViaMemory(wgInterface) - return dynamic.NewRoute( - rt, - routeRefCounter, - allowedIPsRefCounter, - dnsRouterInteval, - statusRecorder, - wgInterface, - fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), - ) + dns := nbdns.NewServiceViaMemory(params.WgInterface) + dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()) + return dynamic.NewRoute(params, dnsAddr) default: - return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) + return static.NewRoute(params) } } diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go new file mode 100644 index 000000000..ed05a08c3 --- /dev/null +++ b/client/internal/routemanager/common/params.go @@ -0,0 +1,28 @@ +package common + +import ( + "time" + + "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/route" +) + +type HandlerParams struct { + Route *route.Route + RouteRefCounter *refcounter.RouteRefCounter + AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter + DnsRouterInteval time.Duration + StatusRecorder *peer.Status + WgInterface iface.WGIface + DnsServer dns.Server + PeerStore *peerstore.Store + UseNewDNSRoute bool + Firewall manager.Manager + FakeIPManager *fakeip.FakeIPManager +} diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 23478c88c..df0a18759 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "runtime" "strings" "sync" @@ -12,11 +13,14 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/common" + "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" @@ -24,6 +28,11 @@ import ( type domainMap map[domain.Domain][]netip.Prefix +type internalDNATer interface { + RemoveInternalDNATMapping(netip.Addr) error + AddInternalDNATMapping(netip.Addr, netip.Addr) error +} + type wgInterface interface { Name() string Address() wgaddr.Address @@ -40,26 +49,22 @@ type DnsInterceptor struct { interceptedDomains domainMap wgInterface wgInterface peerStore *peerstore.Store + firewall firewall.Manager + fakeIPManager *fakeip.FakeIPManager } -func New( - rt *route.Route, - routeRefCounter *refcounter.RouteRefCounter, - allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, - statusRecorder *peer.Status, - dnsServer nbdns.Server, - wgInterface wgInterface, - peerStore *peerstore.Store, -) *DnsInterceptor { +func New(params common.HandlerParams) *DnsInterceptor { return &DnsInterceptor{ - route: rt, - routeRefCounter: routeRefCounter, - allowedIPsRefcounter: allowedIPsRefCounter, - statusRecorder: statusRecorder, - dnsServer: dnsServer, - wgInterface: wgInterface, + route: params.Route, + routeRefCounter: params.RouteRefCounter, + allowedIPsRefcounter: params.AllowedIPsRefCounter, + statusRecorder: params.StatusRecorder, + dnsServer: params.DnsServer, + wgInterface: params.WgInterface, + peerStore: params.PeerStore, + firewall: params.Firewall, + fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), - peerStore: peerStore, } } @@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error { var merr *multierror.Error for domain, prefixes := range d.interceptedDomains { for _, prefix := range prefixes { - if _, err := d.routeRefCounter.Decrement(prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) + // Routes should use fake IPs + routePrefix := d.transformRealToFakePrefix(prefix) + if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err)) } + + // AllowedIPs should use real IPs if d.currentPeerKey != "" { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) @@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error { } } log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) - } + + d.cleanupDNATMappings() + for _, domain := range d.route.Domains { d.statusRecorder.DeleteResolvedDomainsStates(domain) } @@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error { return nberrors.FormatErrorOrNil(merr) } +// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled) +func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix { + if _, hasDNAT := d.internalDnatFw(); !hasDNAT { + return realPrefix + } + + if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok { + return netip.PrefixFrom(fakeIP, realPrefix.Bits()) + } + + return realPrefix +} + +// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs) +func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error { + // AllowedIPs always use real IPs + ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey) + if err != nil { + return fmt.Errorf("add allowed IP %s: %v", realPrefix, err) + } + + if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + realPrefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } + + return nil +} + +// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix +func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error { + // Routes use fake IPs (so traffic to fake IPs gets routed to interface) + routePrefix := d.transformRealToFakePrefix(realPrefix) + if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil { + return fmt.Errorf("add route for IP %s: %v", routePrefix, err) + } + + // Add to AllowedIPs if we have a current peer (uses real IPs) + if d.currentPeerKey == "" { + return nil + } + + return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain) +} + +// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs) +func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error { + if d.currentPeerKey == "" { + return nil + } + + // AllowedIPs use real IPs + if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil { + return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err) + } + + return nil +} + func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { d.mu.Lock() defer d.mu.Unlock() @@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { var merr *multierror.Error for domain, prefixes := range d.interceptedDomains { for _, prefix := range prefixes { - if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) - } else if ref.Count > 1 && ref.Out != peerKey { - log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", - prefix.Addr(), - domain.SafeString(), - ref.Out, - ) + // AllowedIPs use real IPs + if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil { + merr = multierror.Append(merr, err) } } } @@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error { var merr *multierror.Error for _, prefixes := range d.interceptedDomains { for _, prefix := range prefixes { + // AllowedIPs use real IPs if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) } @@ -284,6 +353,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { log.Errorf("failed to update domain prefixes: %v", err) } + + d.replaceIPsInDNSResponse(r, newPrefixes) } } @@ -294,6 +365,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { return nil } +// logPrefixChanges handles the logging for prefix changes +func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) { + if len(toAdd) > 0 { + log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toAdd) + } + if len(toRemove) > 0 && !d.route.KeepRoute { + log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toRemove) + } +} + func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { d.mu.Lock() defer d.mu.Unlock() @@ -302,70 +389,184 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) var merr *multierror.Error + var dnatMappings map[netip.Addr]netip.Addr + + // Handle DNAT mappings for new prefixes + if _, hasDNAT := d.internalDnatFw(); hasDNAT { + dnatMappings = make(map[netip.Addr]netip.Addr) + for _, prefix := range toAdd { + realIP := prefix.Addr() + if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil { + dnatMappings[fakeIP] = realIP + log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP) + } else { + log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err) + } + } + } // Add new prefixes for _, prefix := range toAdd { - if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) - continue - } - - if d.currentPeerKey == "" { - continue - } - if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) - } else if ref.Count > 1 && ref.Out != d.currentPeerKey { - log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", - prefix.Addr(), - resolvedDomain.SafeString(), - ref.Out, - ) + if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil { + merr = multierror.Append(merr, err) } } + d.addDNATMappings(dnatMappings) + if !d.route.KeepRoute { // Remove old prefixes for _, prefix := range toRemove { - if _, err := d.routeRefCounter.Decrement(prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) + // Routes use fake IPs + routePrefix := d.transformRealToFakePrefix(prefix) + if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err)) } - if d.currentPeerKey != "" { - if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) - } + // AllowedIPs use real IPs + if err := d.removeAllowedIP(prefix); err != nil { + merr = multierror.Append(merr, err) } } + + d.removeDNATMappingsForRealIPs(toRemove) } - // Update domain prefixes using resolved domain as key + // Update domain prefixes using resolved domain as key - store real IPs if len(toAdd) > 0 || len(toRemove) > 0 { if d.route.KeepRoute { - // replace stored prefixes with old + added // nolint:gocritic newPrefixes = append(oldPrefixes, toAdd...) } d.interceptedDomains[resolvedDomain] = newPrefixes originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) + + // Store real IPs for status (user-facing), not fake IPs d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) - if len(toAdd) > 0 { - log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", - resolvedDomain.SafeString(), - originalDomain.SafeString(), - toAdd) - } - if len(toRemove) > 0 && !d.route.KeepRoute { - log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", - resolvedDomain.SafeString(), - originalDomain.SafeString(), - toRemove) - } + d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove) } return nberrors.FormatErrorOrNil(merr) } +// removeDNATMappingsForRealIPs removes DNAT mappings from the firewall for real IP prefixes +func (d *DnsInterceptor) removeDNATMappingsForRealIPs(realPrefixes []netip.Prefix) { + if len(realPrefixes) == 0 { + return + } + + dnatFirewall, ok := d.internalDnatFw() + if !ok { + return + } + + for _, prefix := range realPrefixes { + realIP := prefix.Addr() + if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { + if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil { + log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err) + } else { + log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP) + } + } + } +} + +// internalDnatFw checks if the firewall supports internal DNAT +func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) { + if d.firewall == nil || runtime.GOOS != "android" { + return nil, false + } + fw, ok := d.firewall.(internalDNATer) + return fw, ok +} + +// addDNATMappings adds DNAT mappings to the firewall +func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) { + if len(mappings) == 0 { + return + } + + dnatFirewall, ok := d.internalDnatFw() + if !ok { + return + } + + for fakeIP, realIP := range mappings { + if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil { + log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err) + } else { + log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP) + } + } +} + +// removeDNATMappings removes DNAT mappings from the firewall for removed prefixes +func (d *DnsInterceptor) removeDNATMappings(prefixes []netip.Prefix) { + if len(prefixes) == 0 { + return + } + + dnatFirewall, ok := d.internalDnatFw() + if !ok { + return + } + + for _, prefix := range prefixes { + fakeIP := prefix.Addr() + if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil { + log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err) + } else { + log.Debugf("Removed DNAT mapping for: %s", fakeIP) + } + } +} + +// cleanupDNATMappings removes all DNAT mappings for this interceptor +func (d *DnsInterceptor) cleanupDNATMappings() { + if _, ok := d.internalDnatFw(); !ok { + return + } + + for _, prefixes := range d.interceptedDomains { + d.removeDNATMappingsForRealIPs(prefixes) + } +} + +// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response +func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) { + if _, ok := d.internalDnatFw(); !ok { + return + } + + // Replace A and AAAA records with fake IPs + for _, answer := range reply.Answer { + switch rr := answer.(type) { + case *dns.A: + realIP, ok := netip.AddrFromSlice(rr.A) + if !ok { + continue + } + + if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { + rr.A = fakeIP.AsSlice() + log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) + } + + case *dns.AAAA: + realIP, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + continue + } + + if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { + rr.AAAA = fakeIP.AsSlice() + log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) + } + } + } +} + func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { prefixSet := make(map[netip.Prefix]bool) for _, prefix := range oldPrefixes { diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 47511d4af..b263e09ef 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -14,6 +14,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" @@ -52,24 +53,16 @@ type Route struct { resolverAddr string } -func NewRoute( - rt *route.Route, - routeRefCounter *refcounter.RouteRefCounter, - allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, - interval time.Duration, - statusRecorder *peer.Status, - wgInterface iface.WGIface, - resolverAddr string, -) *Route { +func NewRoute(params common.HandlerParams, resolverAddr string) *Route { return &Route{ - route: rt, - routeRefCounter: routeRefCounter, - allowedIPsRefcounter: allowedIPsRefCounter, - interval: interval, - dynamicDomains: domainMap{}, - statusRecorder: statusRecorder, - wgInterface: wgInterface, + route: params.Route, + routeRefCounter: params.RouteRefCounter, + allowedIPsRefcounter: params.AllowedIPsRefCounter, + interval: params.DnsRouterInteval, + statusRecorder: params.StatusRecorder, + wgInterface: params.WgInterface, resolverAddr: resolverAddr, + dynamicDomains: domainMap{}, } } diff --git a/client/internal/routemanager/fakeip/fakeip.go b/client/internal/routemanager/fakeip/fakeip.go new file mode 100644 index 000000000..14cf3c30c --- /dev/null +++ b/client/internal/routemanager/fakeip/fakeip.go @@ -0,0 +1,93 @@ +package fakeip + +import ( + "fmt" + "net/netip" + "sync" +) + +// FakeIPManager manages allocation of fake IPs from the 240.0.0.0/8 block +type FakeIPManager struct { + mu sync.Mutex + nextIP netip.Addr // Next IP to allocate + allocated map[netip.Addr]netip.Addr // real IP -> fake IP + fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP + baseIP netip.Addr // First usable IP: 240.0.0.1 + maxIP netip.Addr // Last usable IP: 240.255.255.254 +} + +// NewManager creates a new fake IP manager using 240.0.0.0/8 block +func NewManager() *FakeIPManager { + baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1}) + maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254}) + + return &FakeIPManager{ + nextIP: baseIP, + allocated: make(map[netip.Addr]netip.Addr), + fakeToReal: make(map[netip.Addr]netip.Addr), + baseIP: baseIP, + maxIP: maxIP, + } +} + +// AllocateFakeIP allocates a fake IP for the given real IP +// Returns the fake IP, or existing fake IP if already allocated +func (f *FakeIPManager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { + if !realIP.Is4() { + return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported") + } + + f.mu.Lock() + defer f.mu.Unlock() + + if fakeIP, exists := f.allocated[realIP]; exists { + return fakeIP, nil + } + + startIP := f.nextIP + for { + currentIP := f.nextIP + + // Advance to next IP, wrapping at boundary + if f.nextIP.Compare(f.maxIP) >= 0 { + f.nextIP = f.baseIP + } else { + f.nextIP = f.nextIP.Next() + } + + // Check if current IP is available + if _, inUse := f.fakeToReal[currentIP]; !inUse { + f.allocated[realIP] = currentIP + f.fakeToReal[currentIP] = realIP + return currentIP, nil + } + + // Prevent infinite loop if all IPs exhausted + if f.nextIP.Compare(startIP) == 0 { + return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block") + } + } +} + +// GetFakeIP returns the fake IP for a real IP if it exists +func (f *FakeIPManager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) { + f.mu.Lock() + defer f.mu.Unlock() + + fakeIP, exists := f.allocated[realIP] + return fakeIP, exists +} + +// GetRealIP returns the real IP for a fake IP if it exists, otherwise false +func (f *FakeIPManager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) { + f.mu.Lock() + defer f.mu.Unlock() + + realIP, exists := f.fakeToReal[fakeIP] + return realIP, exists +} + +// GetFakeIPBlock returns the fake IP block used by this manager +func (f *FakeIPManager) GetFakeIPBlock() netip.Prefix { + return netip.MustParsePrefix("240.0.0.0/8") +} diff --git a/client/internal/routemanager/fakeip/fakeip_test.go b/client/internal/routemanager/fakeip/fakeip_test.go new file mode 100644 index 000000000..d391cf2d0 --- /dev/null +++ b/client/internal/routemanager/fakeip/fakeip_test.go @@ -0,0 +1,242 @@ +package fakeip + +import ( + "net/netip" + "sync" + "testing" +) + +func TestNewManager(t *testing.T) { + manager := NewManager() + + if manager.baseIP.String() != "240.0.0.1" { + t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String()) + } + + if manager.maxIP.String() != "240.255.255.254" { + t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String()) + } + + if manager.nextIP.Compare(manager.baseIP) != 0 { + t.Errorf("Expected nextIP to start at baseIP") + } +} + +func TestAllocateFakeIP(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("8.8.8.8") + + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IP: %v", err) + } + + if !fakeIP.Is4() { + t.Error("Fake IP should be IPv4") + } + + // Check it's in the correct range + if fakeIP.As4()[0] != 240 { + t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String()) + } + + // Should return same fake IP for same real IP + fakeIP2, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to get existing fake IP: %v", err) + } + + if fakeIP.Compare(fakeIP2) != 0 { + t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String()) + } +} + +func TestAllocateFakeIPIPv6Rejection(t *testing.T) { + manager := NewManager() + realIPv6 := netip.MustParseAddr("2001:db8::1") + + _, err := manager.AllocateFakeIP(realIPv6) + if err == nil { + t.Error("Expected error for IPv6 address") + } +} + +func TestGetFakeIP(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("1.1.1.1") + + // Should not exist initially + _, exists := manager.GetFakeIP(realIP) + if exists { + t.Error("Fake IP should not exist before allocation") + } + + // Allocate and check + expectedFakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate: %v", err) + } + + fakeIP, exists := manager.GetFakeIP(realIP) + if !exists { + t.Error("Fake IP should exist after allocation") + } + + if fakeIP.Compare(expectedFakeIP) != 0 { + t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String()) + } +} + + + +func TestMultipleAllocations(t *testing.T) { + manager := NewManager() + + allocations := make(map[netip.Addr]netip.Addr) + + // Allocate multiple IPs + for i := 1; i <= 100; i++ { + realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err) + } + + // Check for duplicates + for _, existingFake := range allocations { + if fakeIP.Compare(existingFake) == 0 { + t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String()) + } + } + + allocations[realIP] = fakeIP + } + + // Verify all allocations can be retrieved + for realIP, expectedFake := range allocations { + actualFake, exists := manager.GetFakeIP(realIP) + if !exists { + t.Errorf("Missing allocation for %s", realIP.String()) + } + if actualFake.Compare(expectedFake) != 0 { + t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String()) + } + } +} + +func TestGetFakeIPBlock(t *testing.T) { + manager := NewManager() + block := manager.GetFakeIPBlock() + + expected := "240.0.0.0/8" + if block.String() != expected { + t.Errorf("Expected %s, got %s", expected, block.String()) + } +} + +func TestConcurrentAccess(t *testing.T) { + manager := NewManager() + + const numGoroutines = 50 + const allocationsPerGoroutine = 10 + + var wg sync.WaitGroup + results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine) + + // Concurrent allocations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + for j := 0; j < allocationsPerGoroutine; j++ { + realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)}) + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err) + return + } + results <- fakeIP + } + }(i) + } + + wg.Wait() + close(results) + + // Check for duplicates + seen := make(map[netip.Addr]bool) + count := 0 + for fakeIP := range results { + if seen[fakeIP] { + t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String()) + } + seen[fakeIP] = true + count++ + } + + if count != numGoroutines*allocationsPerGoroutine { + t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count) + } +} + +func TestIPExhaustion(t *testing.T) { + // Create a manager with limited range for testing + manager := &FakeIPManager{ + nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), + allocated: make(map[netip.Addr]netip.Addr), + fakeToReal: make(map[netip.Addr]netip.Addr), + baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), + maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available + } + + // Allocate all available IPs + realIPs := []netip.Addr{ + netip.MustParseAddr("1.0.0.1"), + netip.MustParseAddr("1.0.0.2"), + netip.MustParseAddr("1.0.0.3"), + } + + for _, realIP := range realIPs { + _, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IP: %v", err) + } + } + + // Try to allocate one more - should fail + _, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4")) + if err == nil { + t.Error("Expected exhaustion error") + } +} + +func TestWrapAround(t *testing.T) { + // Create manager starting near the end of range + manager := &FakeIPManager{ + nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + allocated: make(map[netip.Addr]netip.Addr), + fakeToReal: make(map[netip.Addr]netip.Addr), + baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), + maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + } + + // Allocate the last IP + fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1")) + if err != nil { + t.Fatalf("Failed to allocate first IP: %v", err) + } + + if fakeIP1.String() != "240.0.0.254" { + t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String()) + } + + // Next allocation should wrap around to the beginning + fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2")) + if err != nil { + t.Fatalf("Failed to allocate second IP: %v", err) + } + + if fakeIP2.String() != "240.0.0.1" { + t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String()) + } +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 919bf25e3..3319f90d0 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -24,6 +25,8 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" + "github.com/netbirdio/netbird/client/internal/routemanager/common" + "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -38,6 +41,10 @@ import ( "github.com/netbirdio/netbird/version" ) +type internalDNATer interface { + AddInternalDNATMapping(netip.Addr, netip.Addr) error +} + // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) @@ -49,7 +56,7 @@ type Manager interface { GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string - EnableServerRouter(firewall firewall.Manager) error + SetFirewall(firewall.Manager) error Stop(stateManager *statemanager.Manager) } @@ -89,11 +96,13 @@ type DefaultManager struct { // clientRoutes is the most recent list of clientRoutes received from the Management Service clientRoutes route.HAMap dnsServer dns.Server + firewall firewall.Manager peerStore *peerstore.Store useNewDNSRoute bool disableClientRoutes bool disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler + fakeIPManager *fakeip.FakeIPManager } func NewManager(config ManagerConfig) *DefaultManager { @@ -129,6 +138,8 @@ func NewManager(config ManagerConfig) *DefaultManager { } if runtime.GOOS == "android" { + dm.fakeIPManager = fakeip.NewManager() + cr := dm.initialClientRoutes(config.InitialRoutes) dm.notifier.SetInitialClientRoutes(cr) } @@ -222,16 +233,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector { return routeselector.NewRouteSelector() } -func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { - if m.disableServerRoutes { +// SetFirewall sets the firewall manager for the DefaultManager +// Not thread-safe, should be called before starting the manager +func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { + m.firewall = firewall + + if m.disableServerRoutes || firewall == nil { log.Info("server routes are disabled") return nil } - if firewall == nil { - return errors.New("firewall manager is not set") - } - var err error m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) if err != nil { @@ -299,17 +310,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { } for id, route := range toAdd { - handler := client.HandlerFromRoute( - route, - m.routeRefCounter, - m.allowedIPsRefCounter, - m.dnsRouteInterval, - m.statusRecorder, - m.wgInterface, - m.dnsServer, - m.peerStore, - m.useNewDNSRoute, - ) + params := common.HandlerParams{ + Route: route, + RouteRefCounter: m.routeRefCounter, + AllowedIPsRefCounter: m.allowedIPsRefCounter, + DnsRouterInteval: m.dnsRouteInterval, + StatusRecorder: m.statusRecorder, + WgInterface: m.wgInterface, + DnsServer: m.dnsServer, + PeerStore: m.peerStore, + UseNewDNSRoute: m.useNewDNSRoute, + Firewall: m.firewall, + FakeIPManager: m.fakeIPManager, + } + handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err)) continue @@ -517,9 +531,27 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro for _, routes := range crMap { rs = append(rs, routes...) } + + fakeIPBlock := m.fakeIPManager.GetFakeIPBlock() + id := uuid.NewString() + fakeIPRoute := &route.Route{ + ID: route.ID(id), + Network: fakeIPBlock, + NetID: route.NetID(id), + Peer: m.pubKey, + NetworkType: route.IPv4Network, + } + rs = append(rs, fakeIPRoute) + return rs } +// supportsInternalDNAT checks if the firewall supports internal DNAT +func (m *DefaultManager) supportsInternalDNAT(fw firewall.Manager) bool { + _, ok := fw.(internalDNATer) + return ok +} + func isRouteSupported(route *route.Route) bool { if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() { return true diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 63bad689e..4e182f82c 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -15,7 +15,7 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) - UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error + UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap @@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList } -func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { +func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } diff --git a/client/internal/routemanager/notifier/notifier.go b/client/internal/routemanager/notifier/notifier.go index 25a3a71e0..ebdd60323 100644 --- a/client/internal/routemanager/notifier/notifier.go +++ b/client/internal/routemanager/notifier/notifier.go @@ -32,10 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { nets := make([]string, 0) for _, r := range clientRoutes { - // filter out domain routes - if r.IsDynamic() { - continue - } nets = append(nets, r.Network.String()) } sort.Strings(nets) diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index c8b9338e0..d480fdf00 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/route" ) @@ -16,11 +17,11 @@ type Route struct { allowedIPsRefcounter *refcounter.AllowedIPsRefCounter } -func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route { +func NewRoute(params common.HandlerParams) *Route { return &Route{ - route: rt, - routeRefCounter: routeRefCounter, - allowedIPsRefcounter: allowedIPsRefCounter, + route: params.Route, + routeRefCounter: params.RouteRefCounter, + allowedIPsRefcounter: params.AllowedIPsRefCounter, } }