Merge branch 'main' into refactor/getaccount-raw

This commit is contained in:
crn4
2025-10-30 18:07:17 +01:00
79 changed files with 2122 additions and 425 deletions

View File

@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
@@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx)
resp, err := getStatus(ctx, false)
if err != nil {
return err
}
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}

View File

@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -151,14 +151,20 @@ type Manager interface {
DisableRouting() error
// AddDNATRule adds a DNAT rule
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
// DeleteDNATRule deletes the outbound DNAT rule.
DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -22,6 +22,8 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
}
// these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists {
t.updateState(key, conn, flags, direction, size)
return key, true
return key, uint16(conn.DNATOrigPort.Load()), true
}
return key, false
return key, 0, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 {
return
}
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
t.logger.Trace2("New %s TCP connection: %s", direction, key)
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
}
}
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
key := ConnKey{
SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
if exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
return key, uint16(conn.DNATOrigPort.Load()), true
}
return key, false
return key, 0, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort,
DestPort: dstPort,
}
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace2("New %s UDP connection: %s", direction, key)
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}

View File

@@ -109,6 +109,10 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
}
// decoder for packages
@@ -122,6 +126,8 @@ type decoder struct {
icmp6 layers.ICMPv6
decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
}
// Create userspace firewall manager constructor
@@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
}
m.routingEnabled.Store(false)
@@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
m.trackOutbound(d, srcIP, dstIP, size)
m.trackOutbound(d, srcIP, dstIP, packetData, size)
m.translateOutboundDNAT(packetData, d)
return false
@@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
}
@@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
}
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
@@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)

View File

@@ -50,6 +50,8 @@ type logMessage struct {
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
}
// Logger is a high-performance, non-blocking logger
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
select {
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
}
}
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
}
}
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
case 7:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
case 8:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
}
*buf = append(*buf, formatted...)
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error {
case <-done:
return nil
}
}
}

View File

@@ -5,7 +5,9 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -13,6 +15,21 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum)
}
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum)
}
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
}
}
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
}
}
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
if !originalAddr.IsValid() {
return fmt.Errorf("invalid original IP address")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
// getDNATTranslation returns the translated address if a mapping exists.
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
// findReverseDNATMapping finds original address for return traffic.
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
// translateOutboundDNAT applies DNAT translation to outbound packets.
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error1("Failed to rewrite packet destination: %v", err)
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
return false
}
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
// translateInboundReverse applies reverse DNAT to inbound return traffic.
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error1("Failed to rewrite packet source: %v", err)
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
return false
}
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
var oldIP [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newIPBytes := newIP.As4()
copy(packetData[16:20], newDst[:])
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
return nil
}
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
// incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
// DeleteDNATRule deletes outbound DNAT rule.
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
if rule.origPort != port {
continue
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}

View File

@@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
}
})
}
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,25 +16,33 @@ type PacketStage int
const (
StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageReceived: "Received",
StageInboundPortDNAT: "Inbound Port DNAT",
StageInbound1to1NAT: "Inbound 1:1 NAT",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
}[s]
}
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
dropped := m.filterOutbound(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
}
return trace
}
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageCompleted,
@@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageCompleted,
},
@@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: true,
@@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting,
StagePeerACL,
StageCompleted,

View File

@@ -4,12 +4,15 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"runtime"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -17,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}))
}
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
conn, err := grpc.NewClient(
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
}

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil

View File

@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states.
state.json: Anonymized client state dump containing netbird states for the active profile.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information.
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
log.Debugf("Adding state file from: %s", path)
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {

View File

@@ -13,6 +13,7 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
}
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var (
searchDomains []string
matchDomains []string
)
err = s.recordSystemDNSSettings(true)
if err != nil {
if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
if err := s.addLocalDNS(); err != nil {
log.Warnf("failed to add local DNS: %v", err)
}
s.updateState(stateManager)
}
for _, dConf := range config.Domains {
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else {
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil {
return fmt.Errorf("add match domains: %w", err)
}
s.updateState(stateManager)
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil {
return fmt.Errorf("add search domains: %w", err)
}
s.updateState(stateManager)
if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return nil
}
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (s *systemConfigurator) string() string {
return "scutil"
}
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration")
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
}
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
log.Info("Not enabling local DNS server")
return nil
}
if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
return nil

View File

@@ -0,0 +1,111 @@
//go:build !ios
package dns
import (
"context"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
defer func() {
require.NoError(t, sm.Stop(context.Background()))
}()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
config := HostDNSConfig{
ServerIP: netip.MustParseAddr("100.64.0.1"),
ServerPort: 53,
RouteAll: true,
Domains: []DomainConfig{
{Domain: "example.com", MatchOnly: true},
},
}
err := configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
defer func() {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}()
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
t.Logf("Key %s exists before cleanup", key)
}
}
sm2 := statemanager.New(stateFile)
sm2.RegisterState(&ShutdownState{})
err = sm2.LoadState(&ShutdownState{})
require.NoError(t, err)
state := sm2.GetState(&ShutdownState{})
if state == nil {
t.Skip("State not saved, skipping cleanup test")
}
shutdownState, ok := state.(*ShutdownState)
require.True(t, ok)
err = shutdownState.Cleanup()
require.NoError(t, err)
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}
func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
output, err := cmd.CombinedOutput()
if err != nil {
if strings.Contains(string(output), "No such key") {
return false, nil
}
return false, err
}
return !strings.Contains(string(output), "No such key"), nil
}
func removeTestDNSKey(key string) error {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
_, err := cmd.CombinedOutput()
return err
}

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry"
)
var (
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
r.updateState(stateManager)
var searchDomains, matchDomains []string
for _, dConf := range config.Domains {
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
}
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("cleanup old dns match policies: %s", err)
}
if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
}
r.nrptEntryCount = count
} else {
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
r.nrptEntryCount = 0
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
r.updateState(stateManager)
if err := r.updateSearchDomains(searchDomains); err != nil {
return fmt.Errorf("update search domains: %w", err)
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil
}
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err)
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return fmt.Errorf("remove existing dns policy: %w", err)
}
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
}
defer closer(regKey)

View File

@@ -0,0 +1,102 @@
package dns
import (
"fmt"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry"
)
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
}
k.Close()
return true, nil
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}

View File

@@ -7,6 +7,7 @@ import (
)
type ShutdownState struct {
CreatedKeys []string
}
func (s *ShutdownState) Name() string {
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create host manager: %w", err)
}
for _, key := range s.CreatedKeys {
manager.createdKeys[key] = struct{}{}
}
if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"net"
"sync"
"net/netip"
"os"
"strconv"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -12,18 +14,14 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
listenPort uint16 = 5353
listenPortMu sync.RWMutex
)
const (
dnsTTL = 60 //seconds
dnsTTL = 60
envServerPort = "NB_DNS_FORWARDER_PORT"
)
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
@@ -36,28 +34,30 @@ type ForwarderEntry struct {
type Manager struct {
firewall firewall.Manager
statusRecorder *peer.Status
localAddr netip.Addr
serverPort uint16
fwRules []firewall.Rule
tcpRules []firewall.Rule
dnsForwarder *DNSForwarder
}
func ListenPort() uint16 {
listenPortMu.RLock()
defer listenPortMu.RUnlock()
return listenPort
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
serverPort := nbdns.ForwarderServerPort
if envPort := os.Getenv(envServerPort); envPort != "" {
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
serverPort = uint16(port)
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
} else {
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
}
}
func SetListenPort(port uint16) {
listenPortMu.Lock()
listenPort = port
listenPortMu.Unlock()
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{
firewall: fw,
statusRecorder: statusRecorder,
localAddr: localAddr,
serverPort: serverPort,
}
}
@@ -71,7 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
if m.localAddr.IsValid() && m.firewall != nil {
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
} else {
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
}
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
} else {
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
}
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists
@@ -96,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error {
}
var mErr *multierror.Error
if m.localAddr.IsValid() && m.firewall != nil {
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
}
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
}
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
@@ -111,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error {
func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []uint16{ListenPort()},
Values: []uint16{m.serverPort},
}
if m.firewall == nil {

View File

@@ -203,8 +203,7 @@ type Engine struct {
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
// dns forwarder port
dnsFwdPort uint16
probeStunTurn *relay.StunTurnProbe
}
// Peer is an instance of the Connection Peer
@@ -247,7 +246,7 @@ func NewEngine(
statusRecorder: statusRecorder,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
dnsFwdPort: dnsfwd.ListenPort(),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
sm := profilemanager.NewServiceManager("")
@@ -1060,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1084,7 +1087,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
@@ -1208,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
if forwarderPort == 0 {
forwarderPort = nbdns.ForwarderClientPort
}
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
ForwarderPort: forwarderPort,
}
for _, zone := range protoDNSConfig.GetCustomZones() {
@@ -1667,7 +1676,7 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool {
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy()
@@ -1699,8 +1708,12 @@ func (e *Engine) RunHealthProbes() bool {
}
e.syncMsgMux.Unlock()
results := e.probeICE(stuns, turns)
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true
@@ -1717,13 +1730,6 @@ func (e *Engine) RunHealthProbes() bool {
return allHealthy
}
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
e.syncMsgMux.Lock()
@@ -1843,16 +1849,11 @@ func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
forwarderPort uint16,
) {
if e.config.DisableServerRoutes {
return
}
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1864,20 +1865,17 @@ func (e *Engine) updateDNSForwarder(
}
if len(fwdEntries) > 0 {
switch {
case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if e.dnsForwardMgr == nil {
localAddr := e.wgInterface.Address().IP
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
log.Infof("started domain router service with %d entries", len(fwdEntries))
case e.dnsFwdPort != forwarderPort:
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
e.restartDnsFwd(fwdEntries, forwarderPort)
e.dnsFwdPort = forwarderPort
default:
log.Infof("started domain router service with %d entries", len(fwdEntries))
} else {
e.dnsForwardMgr.UpdateDomains(fwdEntries)
}
} else if e.dnsForwardMgr != nil {
@@ -1887,20 +1885,6 @@ func (e *Engine) updateDNSForwarder(
}
e.dnsForwardMgr = nil
}
}
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
// stop and start the forwarder to apply the new port
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
}
func (e *Engine) GetNet() (*netstack.Net, error) {

View File

@@ -10,10 +10,10 @@ import (
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/dns"
)
type rcvChan chan *types.EventFields
@@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
if !l.dnsCollection.Load() && event.Protocol == types.UDP &&
(event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) {
return false
}

View File

@@ -2,6 +2,8 @@ package relay
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"net"
"sync"
@@ -15,6 +17,15 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
DefaultCacheTTL = 20 * time.Second
probeTimeout = 6 * time.Second
)
var (
ErrCheckInProgress = errors.New("probe check is already in progress")
)
// ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct {
URI string
@@ -22,8 +33,164 @@ type ProbeResult struct {
Addr string
}
type StunTurnProbe struct {
cacheResults []ProbeResult
cacheTimestamp time.Time
cacheKey string
cacheTTL time.Duration
probeInProgress bool
probeDone chan struct{}
mu sync.Mutex
}
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
return &StunTurnProbe{
cacheTTL: cacheTTL,
}
}
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if p.probeInProgress {
doneChan := p.probeDone
p.mu.Unlock()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-doneChan:
return p.getCachedResults(cacheKey, stuns, turns)
}
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
p.mu.Unlock()
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
return p.getCachedResults(cacheKey, stuns, turns)
}
// ProbeAll probes all given servers asynchronously and returns the results
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if results := p.checkCache(cacheKey); results != nil {
p.mu.Unlock()
return results
}
if p.probeInProgress {
p.mu.Unlock()
return createErrorResults(stuns, turns)
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
log.Infof("started new probe for STUN, TURN servers")
go func() {
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
}()
p.mu.Unlock()
timer := time.NewTimer(1300 * time.Millisecond)
defer timer.Stop()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-probeDone:
// when the probe is return fast, return the results right away
return p.getCachedResults(cacheKey, stuns, turns)
case <-timer.C:
// if the probe takes longer than 1.3s, return error results to avoid blocking
return createErrorResults(stuns, turns)
}
}
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
age := time.Since(p.cacheTimestamp)
if age < p.cacheTTL {
results := append([]ProbeResult(nil), p.cacheResults...)
log.Debugf("returning cached probe results (age: %v)", age)
return results
}
}
return nil
}
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
p.mu.Lock()
defer p.mu.Unlock()
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
return append([]ProbeResult(nil), p.cacheResults...)
}
return createErrorResults(stuns, turns)
}
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
defer func() {
p.mu.Lock()
p.probeInProgress = false
p.mu.Unlock()
}()
results := make([]ProbeResult, len(stuns)+len(turns))
var wg sync.WaitGroup
for i, uri := range stuns {
wg.Add(1)
go func(idx int, stunURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = stunURI.String()
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
}(i, uri)
}
stunOffset := len(stuns)
for i, uri := range turns {
wg.Add(1)
go func(idx int, turnURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = turnURI.String()
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
}(stunOffset+i, uri)
}
wg.Wait()
p.mu.Lock()
p.cacheResults = results
p.cacheTimestamp = time.Now()
p.cacheKey = cacheKey
p.mu.Unlock()
log.Debug("Stored new probe results in cache")
}
// ProbeSTUN tries binding to the given STUN uri and acquiring an address
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if probeErr != nil {
log.Debugf("stun probe error from %s: %s", uri, probeErr)
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
// ProbeTURN tries allocating a session from the given TURN URI
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if probeErr != nil {
log.Debugf("turn probe error from %s: %s", uri, probeErr)
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
return relayConn.LocalAddr().String(), nil
}
// ProbeAll probes all given servers asynchronously and returns the results
func ProbeAll(
ctx context.Context,
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
relays []*stun.URI,
) []ProbeResult {
results := make([]ProbeResult, len(relays))
func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
total := len(stuns) + len(turns)
results := make([]ProbeResult, total)
var wg sync.WaitGroup
for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
defer cancel()
wg.Add(1)
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
allURIs := append(append([]*stun.URI{}, stuns...), turns...)
for i, uri := range allURIs {
results[i] = ProbeResult{
URI: uri.String(),
Err: ErrCheckInProgress,
}
}
wg.Wait()
return results
}
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
h := sha256.New()
for _, uri := range stuns {
h.Write([]byte(uri.String()))
}
for _, uri := range turns {
h.Write([]byte(uri.String()))
}
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -1,6 +1,7 @@
package common
import (
"sync/atomic"
"time"
"github.com/netbirdio/netbird/client/firewall/manager"
@@ -25,4 +26,5 @@ type HandlerParams struct {
UseNewDNSRoute bool
Firewall manager.Manager
FakeIPManager *fakeip.Manager
ForwarderPort *atomic.Uint32
}

View File

@@ -8,6 +8,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/go-multierror"
@@ -18,7 +19,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
@@ -55,6 +55,7 @@ type DnsInterceptor struct {
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
forwarderPort *atomic.Uint32
}
func New(params common.HandlerParams) *DnsInterceptor {
@@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor {
firewall: params.Firewall,
fakeIPManager: params.FakeIPManager,
interceptedDomains: make(domainMap),
forwarderPort: params.ForwarderPort,
}
}
@@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()

View File

@@ -10,6 +10,7 @@ import (
"runtime"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -23,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -54,6 +56,7 @@ type Manager interface {
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
SetFirewall(firewall.Manager) error
SetDNSForwarderPort(port uint16)
Stop(stateManager *statemanager.Manager)
}
@@ -101,12 +104,13 @@ type DefaultManager struct {
disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.Manager
dnsForwarderPort atomic.Uint32
}
func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
sysOps := systemops.New(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
@@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
}
dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort))
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
dm.setupRefCounters(useNoop)
@@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
return nil
}
// SetDNSForwarderPort sets the DNS forwarder port for route handlers
func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
m.dnsForwarderPort.Store(uint32(port))
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop()
@@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
UseNewDNSRoute: m.useNewDNSRoute,
Firewall: m.firewall,
FakeIPManager: m.fakeIPManager,
ForwarderPort: &m.dnsForwarderPort,
}
handler := client.HandlerFromRoute(params)
if err := handler.AddRoute(m.ctx); err != nil {

View File

@@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error {
panic("implement me")
}
// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface
func (m *MockManager) SetDNSForwarderPort(port uint16) {
}
// Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop(stateManager *statemanager.Manager) {
if m.StopFunc != nil {

View File

@@ -0,0 +1,8 @@
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
package systemops
// FlushMarkedRoutes is a no-op on non-BSD platforms.
func (r *SysOps) FlushMarkedRoutes() error {
return nil
}

View File

@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s))
sysOps := New(nil, nil)
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
sysOps.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush()
return sysOps.refCounter.Flush()
}
func (s *ShutdownState) MarshalJSON() ([]byte, error) {

View File

@@ -83,7 +83,7 @@ type SysOps struct {
localSubnetsCacheTime time.Time
}
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,

View File

@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil)
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil)
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})
r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err")

View File

@@ -7,19 +7,39 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
"syscall"
"time"
"unsafe"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
)
var routeProtoFlag int
func init() {
switch os.Getenv(envRouteProtoFlag) {
case "2":
routeProtoFlag = unix.RTF_PROTO2
case "3":
routeProtoFlag = unix.RTF_PROTO3
default:
routeProtoFlag = unix.RTF_PROTO1
}
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager)
}
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
return r.cleanupRefCounter(stateManager)
}
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
func (r *SysOps) FlushMarkedRoutes() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("Skipping route flush: %v", err)
continue
}
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
continue
}
nexthop := Nexthop{
IP: routeInfo.Gw,
Intf: routeInfo.Interface,
}
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
continue
}
flushedCount++
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
}
if flushedCount > 0 {
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
}
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
}

View File

@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist")
log.Debugf("state file %s does not exist", m.filePath)
return nil, nil // nolint:nilnil
}
return nil, fmt.Errorf("read state file: %w", err)

View File

@@ -0,0 +1,59 @@
package winregistry
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows/registry"
)
var (
advapi = syscall.NewLazyDLL("advapi32.dll")
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
)
const (
// Registry key options
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
// Registry disposition values
regCreatedNewKey = 0x1
regOpenedExistingKey = 0x2
)
// CreateVolatileKey creates a volatile registry key named path under open key root.
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
// The access parameter specifies the access rights for the key to be created.
//
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
// This provides automatic cleanup without requiring manual registry maintenance.
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
pathPtr, err := syscall.UTF16PtrFromString(path)
if err != nil {
return 0, false, err
}
var (
handle syscall.Handle
disposition uint32
)
ret, _, _ := regCreateKeyExW.Call(
uintptr(root),
uintptr(unsafe.Pointer(pathPtr)),
0, // reserved
0, // class
uintptr(regOptionVolatile), // options - volatile key
uintptr(access), // desired access
0, // security attributes
uintptr(unsafe.Pointer(&handle)),
uintptr(unsafe.Pointer(&disposition)),
)
if ret != 0 {
return 0, false, syscall.Errno(ret)
}
return registry.Key(handle), disposition == regOpenedExistingKey, nil
}

View File

@@ -17,8 +17,7 @@ type Conn struct {
ID hooks.ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn)
}
@@ -29,7 +28,7 @@ type TCPConn struct {
ID hooks.ConnectionID
}
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn)
}
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
// closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close()
cleanupConnID(id)
return err
}
// cleanupConnID executes close hooks for a connection ID.
func cleanupConnID(id hooks.ConnectionID) {
closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks {
if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err)
}
}
return err
}

View File

@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
}
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}

View File

@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
cleanupConnID(connID)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
return fmt.Errorf("resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)

View File

@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn)
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn)

View File

@@ -1057,10 +1057,7 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus {
if msg.ShouldRunProbes {
s.runProbes()
}
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
@@ -1070,7 +1067,7 @@ func (s *Server) Status(
return &statusResponse, nil
}
func (s *Server) runProbes() {
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
}
@@ -1081,7 +1078,7 @@ func (s *Server) runProbes() {
}
if time.Since(s.lastProbe) > probeThreshold {
if engine.RunHealthProbes() {
if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now()
}
}

View File

@@ -10,7 +10,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}
// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/version"
@@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
if !relay.Available {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
if relay.Error == probeRelay.ErrCheckInProgress.Error() {
available = "Checking..."
} else {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {

View File

@@ -296,6 +296,8 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window
connectCancel context.CancelFunc
}
type menuHandler struct {
@@ -592,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
}
}
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return nil, err
return nil, fmt.Errorf("get daemon client: %w", err)
}
activeProf, err := s.profileManager.GetActiveProfile()
if err != nil {
log.Errorf("get active profile: %v", err)
return nil, err
return nil, fmt.Errorf("get active profile: %w", err)
}
currUser, err := user.Current()
@@ -610,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
return nil, fmt.Errorf("get current user: %w", err)
}
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
loginResp, err := conn.Login(ctx, &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name,
Username: &currUser.Username,
})
if err != nil {
log.Errorf("login to management URL with: %v", err)
return nil, err
return nil, fmt.Errorf("login to management: %w", err)
}
if loginResp.NeedsSSOLogin && openURL {
err = s.handleSSOLogin(loginResp, conn)
if err != nil {
log.Errorf("handle SSO login failed: %v", err)
return nil, err
if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
return nil, fmt.Errorf("SSO login: %w", err)
}
}
return loginResp, nil
}
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
err := openURL(loginResp.VerificationURIComplete)
if err != nil {
log.Errorf("opening the verification uri in the browser failed: %v", err)
return err
func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
if err := openURL(loginResp.VerificationURIComplete); err != nil {
return fmt.Errorf("open browser: %w", err)
}
resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {
log.Errorf("waiting sso login failed with: %v", err)
return err
return fmt.Errorf("wait for SSO login: %w", err)
}
if resp.Email != "" {
err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
Email: resp.Email,
})
if err != nil {
log.Warnf("failed to set profile state: %v", err)
}); err != nil {
log.Debugf("failed to set profile state: %v", err)
} else {
s.mProfile.refresh()
}
}
return nil
}
func (s *serviceClient) menuUpClick() error {
func (s *serviceClient) menuUpClick(ctx context.Context) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
systray.SetTemplateIcon(iconErrorMacOS, s.icError)
log.Errorf("get client: %v", err)
return err
return fmt.Errorf("get daemon client: %w", err)
}
_, err = s.login(true)
_, err = s.login(ctx, true)
if err != nil {
log.Errorf("login failed with: %v", err)
return err
return fmt.Errorf("login: %w", err)
}
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return err
return fmt.Errorf("get status: %w", err)
}
if status.Status == string(internal.StatusConnected) {
log.Warnf("already connected")
return nil
}
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Errorf("up service: %v", err)
return err
if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("start connection: %w", err)
}
return nil
@@ -697,24 +684,20 @@ func (s *serviceClient) menuDownClick() error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return err
return fmt.Errorf("get daemon client: %w", err)
}
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return err
return fmt.Errorf("get status: %w", err)
}
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
log.Warnf("already down")
return nil
}
if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("down service: %v", err)
return err
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("stop connection: %w", err)
}
return nil
@@ -850,6 +833,7 @@ func (s *serviceClient) onTrayReady() {
newProfileMenuArgs := &newProfileMenuArgs{
ctx: s.ctx,
serviceClient: s,
profileManager: s.profileManager,
eventHandler: s.eventHandler,
profileMenuItem: profileMenuItem,
@@ -1381,7 +1365,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return
}
resp, err := s.login(false)
resp, err := s.login(ctx, false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
@@ -1401,7 +1385,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return
}
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
_, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
if err != nil {
log.Errorf("Waiting sso login failed with: %v", err)
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
@@ -1409,7 +1393,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
}
label.SetText("Re-authentication successful.\nReconnecting")
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return
@@ -1422,7 +1406,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return
}
_, err = conn.Up(s.ctx, &proto.UpRequest{})
_, err = conn.Up(ctx, &proto.UpRequest{})
if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
log.Errorf("Reconnecting failed with: %v", err)

View File

@@ -18,6 +18,7 @@ import (
"github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types"
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
return "", err
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err)
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err)
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}

View File

@@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2"
"fyne.io/systray"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
@@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) {
func (h *eventHandler) handleConnectClick() {
h.client.mUp.Disable()
if h.client.connectCancel != nil {
h.client.connectCancel()
}
connectCtx, connectCancel := context.WithCancel(h.client.ctx)
h.client.connectCancel = connectCancel
go func() {
defer h.client.mUp.Enable()
if err := h.client.menuUpClick(); err != nil {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
defer connectCancel()
if err := h.client.menuUpClick(connectCtx); err != nil {
st, ok := status.FromError(err)
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user")
} else {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
log.Errorf("connect failed: %v", err)
}
}
if err := h.client.updateStatus(); err != nil {
log.Debugf("failed to update status after connect: %v", err)
}
}()
}
func (h *eventHandler) handleDisconnectClick() {
h.client.mDown.Disable()
if h.client.connectCancel != nil {
log.Debugf("cancelling ongoing connect operation")
h.client.connectCancel()
h.client.connectCancel = nil
}
go func() {
defer h.client.mDown.Enable()
if err := h.client.menuDownClick(); err != nil {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon"))
st, ok := status.FromError(err)
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
log.Errorf("disconnect failed: %v", err)
} else {
log.Debugf("disconnect cancelled or already disconnecting")
}
}
if err := h.client.updateStatus(); err != nil {
log.Debugf("failed to update status after disconnect: %v", err)
}
}()
}
@@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
}
h.client.getSrvConfig()
return nil
}

View File

@@ -387,6 +387,7 @@ type subItem struct {
type profileMenu struct {
mu sync.Mutex
ctx context.Context
serviceClient *serviceClient
profileManager *profilemanager.ProfileManager
eventHandler *eventHandler
profileMenuItem *systray.MenuItem
@@ -396,7 +397,7 @@ type profileMenu struct {
logoutSubItem *subItem
profilesState []Profile
downClickCallback func() error
upClickCallback func() error
upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -404,12 +405,13 @@ type profileMenu struct {
type newProfileMenuArgs struct {
ctx context.Context
serviceClient *serviceClient
profileManager *profilemanager.ProfileManager
eventHandler *eventHandler
profileMenuItem *systray.MenuItem
emailMenuItem *systray.MenuItem
downClickCallback func() error
upClickCallback func() error
upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -418,6 +420,7 @@ type newProfileMenuArgs struct {
func newProfileMenu(args newProfileMenuArgs) *profileMenu {
p := profileMenu{
ctx: args.ctx,
serviceClient: args.serviceClient,
profileManager: args.profileManager,
eventHandler: args.eventHandler,
profileMenuItem: args.profileMenuItem,
@@ -569,10 +572,19 @@ func (p *profileMenu) refresh() {
}
}
if err := p.upClickCallback(); err != nil {
if p.serviceClient.connectCancel != nil {
p.serviceClient.connectCancel()
}
connectCtx, connectCancel := context.WithCancel(p.ctx)
p.serviceClient.connectCancel = connectCancel
if err := p.upClickCallback(connectCtx); err != nil {
log.Errorf("failed to handle up click after switching profile: %v", err)
}
connectCancel()
p.refresh()
p.loadSettingsCallback()
}

View File

@@ -19,6 +19,10 @@ const (
RootZone = "."
// DefaultClass is the class supported by the system
DefaultClass = "IN"
// ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort.
ForwarderClientPort uint16 = 5353
// ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here.
ForwarderServerPort uint16 = 22054
)
const invalidHostLabel = "[^a-zA-Z0-9-]+"
@@ -31,6 +35,8 @@ type Config struct {
NameServerGroups []*NameServerGroup
// CustomZones contains a list of custom zone
CustomZones []CustomZone
// ForwarderPort is the port clients should connect to on routing peers for DNS forwarding
ForwarderPort uint16
}
// CustomZone represents a custom zone to be resolved by the dns server

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
echo ""
export NETBIRD_SIGNAL_PROTOCOL="https"
unset NETBIRD_LETSENCRYPT_DOMAIN
unset NETBIRD_MGMT_API_CERT_FILE
unset NETBIRD_MGMT_API_CERT_KEY_FILE
fi
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
export NETBIRD_SIGNAL_PROTOCOL="https"
fi
# Check if management identity provider is set
if [ -n "$NETBIRD_MGMT_IDP" ]; then
EXTRA_CONFIG={}

View File

@@ -40,13 +40,21 @@ services:
signal:
<<: *default
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
depends_on:
- dashboard
volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
ports:
- $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
command: [
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
"--log-file", "console"
]
# Relay
relay:

View File

@@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
return integrations.InitPermissionsManager(s.Store())
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
s.AfterInit(func(s *BaseServer) {
manager.SetAccountManager(s.AccountManager())
})
return manager
})
}

View File

@@ -109,7 +109,7 @@ type Manager interface {
GetIdpManager() idp.Manager
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error

View File

@@ -21,8 +21,8 @@ import (
)
const (
dnsForwarderPort = 22054
oldForwarderPort = 5353
dnsForwarderPort = nbdns.ForwarderServerPort
oldForwarderPort = nbdns.ForwarderClientPort
)
const dnsForwarderPortMinVersion = "v0.59.0"
@@ -196,7 +196,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 {
return oldForwarderPort
return int64(oldForwarderPort)
}
reqVer := semver.Canonical(requiredVersion)
@@ -211,17 +211,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
// If any peer doesn't have version info, return 0
return oldForwarderPort
return int64(oldForwarderPort)
}
// Compare versions
if semver.Compare(peerVersion, reqVer) < 0 {
return oldForwarderPort
return int64(oldForwarderPort)
}
}
// All peers have the required version or newer
return dnsForwarderPort
return int64(dnsForwarderPort)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache

View File

@@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
@@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
}
@@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort)
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
@@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) {
// Test with empty peers list
peers := []*nbpeer.Peer{}
result := computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort {
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
}
@@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) {
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort {
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
}
@@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) {
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != dnsForwarderPort {
if result != int64(dnsForwarderPort) {
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
}
@@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) {
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort {
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
}
@@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) {
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort {
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
}
@@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) {
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result == oldForwarderPort {
if result == int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
}
@@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) {
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort {
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
}
}

View File

@@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
@@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
}
_, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid))
reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
_, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
h.setApprovalRequiredFlag(respBody, validPeersMap)
h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
util.WriteJSONObject(r.Context(), w, respBody)
}
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id]
_, ok := validPeersMap[peer.Id]
if !ok {
peer.ApprovalRequired = true
reason := invalidPeersMap[peer.Id]
peer.DisapprovalReason = &reason
}
}
}
@@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
}
}
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
}
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
}
return &api.Peer{
apiPeer := &api.Peer{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
@@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
}
if !approved {
apiPeer.DisapprovalReason = &reason
}
return apiPeer
}
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {

View File

@@ -7,9 +7,10 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"

View File

@@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
return true, nil
}
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
var err error
var groups []*types.Group
var peers []*nbpeer.Peer
@@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
return nil, nil, err
}
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return nil, err
return nil, nil, err
}
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
return nil, nil, err
}
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
if err != nil {
return nil, nil, err
}
invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra)
if err != nil {
return nil, nil, err
}
return validPeers, invalidPeers, nil
}
type MockIntegratedValidator struct {
@@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) {
return make(map[string]string), nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer
}

View File

@@ -15,6 +15,7 @@ type IntegratedValidator interface {
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
Stop(ctx context.Context)

View File

@@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
return nil, nil, err
}
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
approvedPeers[id] = struct{}{}
}
return approvedPeers, nil
return approvedPeers, nil, nil
}
// GetGroup mock implementation of GetGroup from server.AccountManager interface

View File

@@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) {
}
dnsCache := &DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort)
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
assert.NotNil(t, response)
// assert peer config

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -22,6 +23,7 @@ type Manager interface {
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
SetAccountManager(accountManager account.Manager)
}
type managerImpl struct {
@@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
return permissions, nil
}
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
// no-op
}

View File

@@ -9,6 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
account "github.com/netbirdio/netbird/management/server/account"
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
@@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
}
// SetAccountManager mocks base method.
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAccountManager", accountManager)
}
// SetAccountManager indicates an expected call of SetAccountManager.
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
}
// ValidateAccountAccess mocks base method.
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
m.ctrl.T.Helper()

View File

@@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
@@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool {
}
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
peerIPs := make(map[string]struct{})
@@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
peerIPs[peerToConnect.IP.String()] = struct{}{}
}
for _, expiredPeer := range expiredPeers {
peerIPs[expiredPeer.IP.String()] = struct{}{}
}
for _, record := range customZone.Records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)

View File

@@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
peer *nbpeer.Peer
customZone nbdns.CustomZone
peersToConnect []*nbpeer.Peer
expiredPeers []*nbpeer.Peer
expectedRecords []nbdns.SimpleRecord
}{
{
@@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
},
},
peersToConnect: []*nbpeer.Peer{},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
@@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
}
return peers
}(),
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
@@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
@@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
{
name: "expired peers are included in DNS entries",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{
{ID: "peer1", IP: net.ParseIP("10.0.0.1")},
},
expiredPeers: []*nbpeer.Peer{
{ID: "expired-peer", IP: net.ParseIP("10.0.0.99")},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers)
assert.Equal(t, len(tt.expectedRecords), len(result))
assert.ElementsMatch(t, tt.expectedRecords, result)
})

View File

@@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then
NETBIRD_RELEASE=latest
fi
TAG_NAME=""
get_release() {
local RELEASE=$1
if [ "$RELEASE" = "latest" ]; then
@@ -38,17 +40,19 @@ get_release() {
local TAG="tags/${RELEASE}"
local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
fi
OUTPUT=""
if [ -n "$GITHUB_TOKEN" ]; then
curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}")
else
curl -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
OUTPUT=$(curl -s "${URL}")
fi
TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1)
echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+'
}
download_release_binary() {
VERSION=$(get_release "$NETBIRD_RELEASE")
echo "Using the following tag name for binary installation: ${TAG_NAME}"
BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download"
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz"

View File

@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}

View File

@@ -463,6 +463,9 @@ components:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
disapproval_reason:
description: (Cloud only) Reason why the peer requires approval
type: string
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:

View File

@@ -1037,6 +1037,9 @@ type Peer struct {
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DisapprovalReason (Cloud only) Reason why the peer requires approval
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1124,6 +1127,9 @@ type PeerBatch struct {
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DisapprovalReason (Cloud only) Reason why the peer requires approval
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`

View File

@@ -410,7 +410,7 @@ message DNSConfig {
bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3;
int64 ForwarderPort = 4;
int64 ForwarderPort = 4 [deprecated = true];
}
// CustomZone represents a dns.CustomZone

View File

@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}

View File

@@ -94,7 +94,7 @@ var (
startPprof()
opts, certManager, err := getTLSConfigurations()
opts, certManager, tlsConfig, err := getTLSConfigurations()
if err != nil {
return err
}
@@ -132,7 +132,7 @@ var (
// Start the main server - always serve HTTP with WebSocket proxy support
// If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager
if certManager == nil {
if tlsConfig == nil {
// Without TLS, serve plain HTTP
httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort))
if err != nil {
@@ -140,9 +140,10 @@ var (
}
log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String())
serveHTTP(httpListener, grpcRootHandler)
} else if signalPort != 443 {
// With TLS but not on port 443, serve HTTPS
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig())
} else if certManager == nil || signalPort != 443 {
// Serve HTTPS if not already handled by startServerWithCertManager
// (custom certificates or Let's Encrypt with custom port)
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig)
if err != nil {
return err
}
@@ -202,7 +203,7 @@ func startPprof() {
}()
}
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) {
var (
err error
certManager *autocert.Manager
@@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
log.Infof("running without TLS")
return nil, nil, nil
return nil, nil, nil, nil
}
if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil {
return nil, certManager, err
return nil, certManager, nil, err
}
tlsConfig = certManager.TLSConfig()
log.Infof("setting up TLS with LetsEncrypt.")
} else {
if signalCertFile == "" || signalCertKey == "" {
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
}
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
return nil, certManager, err
return nil, certManager, nil, err
}
log.Infof("setting up TLS with custom certificates.")
}
transportCredentials := credentials.NewTLS(tlsConfig)
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err
}
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {