Merge branch 'main' into refactor/getaccount-raw

This commit is contained in:
crn4
2025-10-30 18:07:17 +01:00
79 changed files with 2122 additions and 425 deletions

View File

@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
if err != nil { if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
} }
@@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
func getStatusOutput(cmd *cobra.Command, anon bool) string { func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context()) statusResp, err := getStatus(cmd.Context(), true)
if err != nil { if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx) resp, err := getStatus(ctx, false)
if err != nil { if err != nil {
return err return err
} }
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatus(ctx context.Context) (*proto.StatusResponse, error) { func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }

View File

@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {
if port == nil { if port == nil {
return nil return nil

View File

@@ -151,14 +151,20 @@ type Manager interface {
DisableRouting() error DisableRouting() error
// AddDNATRule adds a DNAT rule // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
AddDNATRule(ForwardRule) (Rule, error) AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule // DeleteDNATRule deletes the outbound DNAT rule.
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes // UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets // applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork( func (r *router) applyNetwork(
network firewall.Network, network firewall.Network,

View File

@@ -22,6 +22,8 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64 PacketsRx atomic.Uint64
BytesTx atomic.Uint64 BytesTx atomic.Uint64
BytesRx atomic.Uint64 BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists { if exists {
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
return key, true return key, uint16(conn.DNATOrigPort.Load()), true
} }
return key, false return key, 0, false
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
// if (inverted direction) conn is not tracked, track this direction return origPort
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
} }
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
return 0
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 { if exists || flags&TCPSyn == 0 {
return return
} }
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false) conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew)) conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
t.logger.Trace2("New %s TCP connection: %s", direction, key) if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
t.mutex.Lock() t.mutex.Lock()
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
} }
} }
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.tickerCancel() t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80) serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound) // 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
key := ConnKey{ key := ConnKey{
SrcIP: clientIP, SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake // 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer // 4. Test data transfer
// Client sends data // Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
// Server sends ACK for data // Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data // Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
// Verify state and counters // Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState()) require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
// if (inverted direction) conn is not tracked, track this direction if exists {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) return origPort
} }
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
return 0
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
} }
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
return key, true return key, uint16(conn.DNATOrigPort.Load()), true
} }
return key, false return key, 0, false
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists { if exists {
return return
} }
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
} }
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace2("New %s UDP connection: %s", direction, key) if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }

View File

@@ -109,6 +109,10 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex dnatMutex sync.RWMutex
dnatBiMap *biDNATMap dnatBiMap *biDNATMap
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
} }
// decoder for packages // decoder for packages
@@ -122,6 +126,8 @@ type decoder struct {
icmp6 layers.ICMPv6 icmp6 layers.ICMPv6
decoded []gopacket.LayerType decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
@@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr), dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
} }
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
@@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true return true
} }
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, packetData, size)
m.translateOutboundDNAT(packetData, d) m.translateOutboundDNAT(packetData, d)
return false return false
@@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
} }
@@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
} }
d.dnatOrigPort = 0
} }
// udpHooksDrop checks if any UDP hooks should drop the packet // udpHooksDrop checks if any UDP hooks should drop the packet
@@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false return false
} }
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated { if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses // Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true return true
} }
srcIP, dstIP = m.extractIPs(d) srcIP, dstIP = m.extractIPs(d)

View File

@@ -50,6 +50,8 @@ type logMessage struct {
arg4 any arg4 any
arg5 any arg5 any
arg6 any arg6 any
arg7 any
arg8 any
} }
// Logger is a high-performance, non-blocking logger // Logger is a high-performance, non-blocking logger
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
func (l *Logger) Error(format string) { func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
select { select {
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
} }
} }
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
} }
} }
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0] *buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
argCount++ argCount++
if msg.arg6 != nil { if msg.arg6 != nil {
argCount++ argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
} }
} }
} }
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6: case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
case 7:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
case 8:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
} }
*buf = append(*buf, formatted...) *buf = append(*buf, formatted...)
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error {
case <-done: case <-done:
return nil return nil
} }
} }

View File

@@ -5,7 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -13,6 +15,21 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 { func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 { if len(header) < 20 {
return 0 return 0
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 { func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32 var sum1, sum2, sum3, sum4 uint32
i := 0 i := 0
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct { type biDNATMap struct {
forward map[netip.Addr]netip.Addr forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr
} }
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap { func newBiDNATMap() *biDNATMap {
return &biDNATMap{ return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr), forward: make(map[netip.Addr]netip.Addr),
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
} }
} }
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) { func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated b.forward[original] = translated
b.reverse[translated] = original b.reverse[translated] = original
} }
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) { func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists { if translated, exists := b.forward[original]; exists {
delete(b.forward, original) delete(b.forward, original)
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
} }
} }
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original] translated, exists := b.forward[original]
return translated, exists return translated, exists
} }
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated] original, exists := b.reverse[translated]
return original, exists return original, exists
} }
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() { if !originalAddr.IsValid() {
return fmt.Errorf("invalid IP addresses") return fmt.Errorf("invalid original IP address")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
} }
if m.localipmanager.IsLocalIP(translatedAddr) { if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil { if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap() m.dnatBiMap = newBiDNATMap()
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil return nil
} }
// RemoveInternalDNATMapping removes a 1:1 IP address mapping // RemoveInternalDNATMapping removes a 1:1 IP address mapping.
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil return nil
} }
// getDNATTranslation returns the translated address if a mapping exists // getDNATTranslation returns the translated address if a mapping exists.
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return addr, false return addr, false
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists return translated, exists
} }
// findReverseDNATMapping finds original address for return traffic // findReverseDNATMapping finds original address for return traffic.
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return translatedAddr, false return translatedAddr, false
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists return original, exists
} }
// translateOutboundDNAT applies DNAT translation to outbound packets // translateOutboundDNAT applies DNAT translation to outbound packets.
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP) translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("Failed to rewrite packet destination: %v", err) m.logger.Error1("failed to rewrite packet destination: %v", err)
return false return false
} }
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true return true
} }
// translateInboundReverse applies reverse DNAT to inbound return traffic // translateInboundReverse applies reverse DNAT to inbound return traffic.
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP) originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("Failed to rewrite packet source: %v", err) m.logger.Error1("failed to rewrite packet source: %v", err)
return false return false
} }
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true return true
} }
// rewritePacketDestination replaces destination IP in the packet // rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { if !newIP.Is4() {
return ErrIPv4Only return ErrIPv4Only
} }
var oldDst [4]byte var oldIP [4]byte
copy(oldDst[:], packetData[16:20]) copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newDst := newIP.As4() newIPBytes := newIP.As4()
copy(packetData[16:20], newDst[:]) copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4 ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length") return errInvalidIPHeaderLength
} }
binary.BigEndian.PutUint16(packetData[10:12], 0) binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
if len(d.decoded) > 1 { if len(d.decoded) > 1 {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen) m.updateICMPChecksum(packetData, ipHeaderLen)
} }
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
return nil return nil
} }
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 { if len(packetData) < tcpStart+18 {
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen udpStart := ipHeaderLen
if len(packetData) < udpStart+8 { if len(packetData) < udpStart+8 {
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 { if len(packetData) < icmpStart+8 {
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum) binary.BigEndian.PutUint16(icmpData[2:4], checksum)
} }
// incrementalUpdate performs incremental checksum update per RFC 1624 // incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum) sum := uint32(^oldChecksum)
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) // AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return nil, errNatNotSupported return nil, errNatNotSupported
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule) return m.nativeFirewall.AddDNATRule(rule)
} }
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) // DeleteDNATRule deletes outbound DNAT rule.
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errNatNotSupported return errNatNotSupported
} }
return m.nativeFirewall.DeleteDNATRule(rule) return m.nativeFirewall.DeleteDNATRule(rule)
} }
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
if rule.origPort != port {
continue
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}

View File

@@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
} }
}) })
} }
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP) err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping") require.Error(t, err, "Should error when removing non-existent mapping")
} }
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,25 +16,33 @@ type PacketStage int
const ( const (
StageReceived PacketStage = iota StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack StageConntrack
StagePeerACL StagePeerACL
StageRouting StageRouting
StageRouteACL StageRouteACL
StageForwarding StageForwarding
StageCompleted StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
) )
const msgProcessingCompleted = "Processing completed" const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string { func (s PacketStage) String() string {
return map[PacketStage]string{ return map[PacketStage]string{
StageReceived: "Received", StageReceived: "Received",
StageConntrack: "Connection Tracking", StageInboundPortDNAT: "Inbound Port DNAT",
StagePeerACL: "Peer ACL", StageInbound1to1NAT: "Inbound 1:1 NAT",
StageRouting: "Routing", StageConntrack: "Connection Tracking",
StageRouteACL: "Route ACL", StagePeerACL: "Peer ACL",
StageForwarding: "Forwarding", StageRouting: "Routing",
StageCompleted: "Completed", StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
}[s] }[s]
} }
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
} }
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
dropped := m.filterOutbound(packetData, 0) dropped := m.filterOutbound(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
} }
return trace return trace
} }
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageCompleted, StageCompleted,
@@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageCompleted, StageCompleted,
}, },
@@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted, StageCompleted,
}, },
expectedAllow: true, expectedAllow: true,
@@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
StageCompleted, StageCompleted,

View File

@@ -4,12 +4,15 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"runtime" "runtime"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -17,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
) )
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls // Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff { func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff() b := backoff.NewExponentialBackOff()
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options. // CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
})) }))
} }
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) conn, err := grpc.NewClient(
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr, addr,
transportOption, transportOption,
WithCustomDialer(tlsEnabled, component), WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
}), }),
) )
if err != nil { if err != nil {
log.Printf("DialContext error: %v", err) return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err return nil, err
} }

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { func WithCustomDialer(_ bool, _ string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
currentUser, err := user.Current() currentUser, err := user.Current()
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
} }
return conn, nil return conn, nil

View File

@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client. config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states. state.json: Anonymized client state dump containing netbird states for the active profile.
mutex.prof: Mutex profiling information. mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information. goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information. block.prof: Block profiling information.
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
return nil return nil
} }
log.Debugf("Adding state file from: %s", path)
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
} }
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var ( var (
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string
) )
err = s.recordSystemDNSSettings(true) if err := s.recordSystemDNSSettings(true); err != nil {
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error()) log.Errorf("unable to update record of System's DNS config: %s", err.Error())
} }
if config.RouteAll { if config.RouteAll {
searchDomains = append(searchDomains, "\"\"") searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS() if err := s.addLocalDNS(); err != nil {
if err != nil { log.Warnf("failed to add local DNS: %v", err)
log.Infof("failed to enable split DNS")
} }
s.updateState(stateManager)
} }
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else { } else {
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil { if err != nil {
return fmt.Errorf("add match domains: %w", err) return fmt.Errorf("add match domains: %w", err)
} }
s.updateState(stateManager)
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 { if len(searchDomains) != 0 {
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil { if err != nil {
return fmt.Errorf("add search domains: %w", err) return fmt.Errorf("add search domains: %w", err)
} }
s.updateState(stateManager)
if err := s.flushDNSCache(); err != nil { if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err) log.Errorf("failed to flush DNS cache: %v", err)
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return nil return nil
} }
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (s *systemConfigurator) string() string { func (s *systemConfigurator) string() string {
return "scutil" return "scutil"
} }
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error { func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
if err := s.recordSystemDNSSettings(true); err != nil { if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration")
return fmt.Errorf("recordSystemDNSSettings(): %w", err) return fmt.Errorf("recordSystemDNSSettings(): %w", err)
} }
} }
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server") log.Info("Not enabling local DNS server")
return nil
}
if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
} }
return nil return nil

View File

@@ -0,0 +1,111 @@
//go:build !ios
package dns
import (
"context"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
defer func() {
require.NoError(t, sm.Stop(context.Background()))
}()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
config := HostDNSConfig{
ServerIP: netip.MustParseAddr("100.64.0.1"),
ServerPort: 53,
RouteAll: true,
Domains: []DomainConfig{
{Domain: "example.com", MatchOnly: true},
},
}
err := configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
defer func() {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}()
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
t.Logf("Key %s exists before cleanup", key)
}
}
sm2 := statemanager.New(stateFile)
sm2.RegisterState(&ShutdownState{})
err = sm2.LoadState(&ShutdownState{})
require.NoError(t, err)
state := sm2.GetState(&ShutdownState{})
if state == nil {
t.Skip("State not saved, skipping cleanup test")
}
shutdownState, ok := state.(*ShutdownState)
require.True(t, ok)
err = shutdownState.Cleanup()
require.NoError(t, err)
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}
func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
output, err := cmd.CombinedOutput()
if err != nil {
if strings.Contains(string(output), "No such key") {
return false, nil
}
return false, err
}
return !strings.Contains(string(output), "No such key"), nil
}
func removeTestDNSKey(key string) error {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
_, err := cmd.CombinedOutput()
return err
}

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry"
) )
var ( var (
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{ r.updateState(stateManager)
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var searchDomains, matchDomains []string var searchDomains, matchDomains []string
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
} }
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("cleanup old dns match policies: %s", err)
}
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil { if err != nil {
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
r.nrptEntryCount = count r.nrptEntryCount = count
} else { } else {
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
r.nrptEntryCount = 0 r.nrptEntryCount = 0
} }
if err := stateManager.UpdateState(&ShutdownState{ r.updateState(stateManager)
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
if err := r.updateSearchDomains(searchDomains); err != nil { if err := r.updateSearchDomains(searchDomains); err != nil {
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil return nil
} }
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
} }
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil { if err != nil {
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
} }
defer closer(regKey) defer closer(regKey)

View File

@@ -0,0 +1,102 @@
package dns
import (
"fmt"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry"
)
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
}
k.Close()
return true, nil
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}

View File

@@ -7,6 +7,7 @@ import (
) )
type ShutdownState struct { type ShutdownState struct {
CreatedKeys []string
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
for _, key := range s.CreatedKeys {
manager.createdKeys[key] = struct{}{}
}
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }

View File

@@ -4,7 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync" "net/netip"
"os"
"strconv"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -12,18 +14,14 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
var (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
listenPort uint16 = 5353
listenPortMu sync.RWMutex
)
const ( const (
dnsTTL = 60 //seconds dnsTTL = 60
envServerPort = "NB_DNS_FORWARDER_PORT"
) )
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
@@ -36,28 +34,30 @@ type ForwarderEntry struct {
type Manager struct { type Manager struct {
firewall firewall.Manager firewall firewall.Manager
statusRecorder *peer.Status statusRecorder *peer.Status
localAddr netip.Addr
serverPort uint16
fwRules []firewall.Rule fwRules []firewall.Rule
tcpRules []firewall.Rule tcpRules []firewall.Rule
dnsForwarder *DNSForwarder dnsForwarder *DNSForwarder
} }
func ListenPort() uint16 { func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
listenPortMu.RLock() serverPort := nbdns.ForwarderServerPort
defer listenPortMu.RUnlock() if envPort := os.Getenv(envServerPort); envPort != "" {
return listenPort if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
} serverPort = uint16(port)
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
} else {
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
}
}
func SetListenPort(port uint16) {
listenPortMu.Lock()
listenPort = port
listenPortMu.Unlock()
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{ return &Manager{
firewall: fw, firewall: fw,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
localAddr: localAddr,
serverPort: serverPort,
} }
} }
@@ -71,7 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err return err
} }
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) if m.localAddr.IsValid() && m.firewall != nil {
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
} else {
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
}
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
} else {
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
}
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
go func() { go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil { if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists // todo handle close error if it is exists
@@ -96,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error {
} }
var mErr *multierror.Error var mErr *multierror.Error
if m.localAddr.IsValid() && m.firewall != nil {
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
}
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
}
if err := m.dropDNSFirewall(); err != nil { if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err) mErr = multierror.Append(mErr, err)
} }
@@ -111,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error {
func (m *Manager) allowDNSFirewall() error { func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{ dport := &firewall.Port{
IsRange: false, IsRange: false,
Values: []uint16{ListenPort()}, Values: []uint16{m.serverPort},
} }
if m.firewall == nil { if m.firewall == nil {

View File

@@ -203,8 +203,7 @@ type Engine struct {
wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup wgIfaceMonitorWg sync.WaitGroup
// dns forwarder port probeStunTurn *relay.StunTurnProbe
dnsFwdPort uint16
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -247,7 +246,7 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
dnsFwdPort: dnsfwd.ListenPort(), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
} }
sm := profilemanager.NewServiceManager("") sm := profilemanager.NewServiceManager("")
@@ -1060,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
} }
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1084,7 +1087,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules // Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
@@ -1208,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
if forwarderPort == 0 {
forwarderPort = nbdns.ForwarderClientPort
}
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0),
ForwarderPort: forwarderPort,
} }
for _, zone := range protoDNSConfig.GetCustomZones() { for _, zone := range protoDNSConfig.GetCustomZones() {
@@ -1667,7 +1676,7 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states. // and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool { func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy() signalHealthy := e.signal.IsHealthy()
@@ -1699,8 +1708,12 @@ func (e *Engine) RunHealthProbes() bool {
} }
e.syncMsgMux.Unlock() e.syncMsgMux.Unlock()
var results []relay.ProbeResult
results := e.probeICE(stuns, turns) if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results) e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true relayHealthy := true
@@ -1717,13 +1730,6 @@ func (e *Engine) RunHealthProbes() bool {
return allHealthy return allHealthy
} }
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
)
}
// restartEngine restarts the engine by cancelling the client context // restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() { func (e *Engine) restartEngine() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
@@ -1843,16 +1849,11 @@ func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) updateDNSForwarder( func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
forwarderPort uint16,
) { ) {
if e.config.DisableServerRoutes { if e.config.DisableServerRoutes {
return return
} }
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled { if !enabled {
if e.dnsForwardMgr == nil { if e.dnsForwardMgr == nil {
return return
@@ -1864,20 +1865,17 @@ func (e *Engine) updateDNSForwarder(
} }
if len(fwdEntries) > 0 { if len(fwdEntries) > 0 {
switch { if e.dnsForwardMgr == nil {
case e.dnsForwardMgr == nil: localAddr := e.wgInterface.Address().IP
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err) log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
log.Infof("started domain router service with %d entries", len(fwdEntries))
case e.dnsFwdPort != forwarderPort:
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
e.restartDnsFwd(fwdEntries, forwarderPort)
e.dnsFwdPort = forwarderPort
default: log.Infof("started domain router service with %d entries", len(fwdEntries))
} else {
e.dnsForwardMgr.UpdateDomains(fwdEntries) e.dnsForwardMgr.UpdateDomains(fwdEntries)
} }
} else if e.dnsForwardMgr != nil { } else if e.dnsForwardMgr != nil {
@@ -1887,20 +1885,6 @@ func (e *Engine) updateDNSForwarder(
} }
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
}
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
// stop and start the forwarder to apply the new port
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} }
func (e *Engine) GetNet() (*netstack.Net, error) { func (e *Engine) GetNet() (*netstack.Net, error) {

View File

@@ -10,10 +10,10 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/dns"
) )
type rcvChan chan *types.EventFields type rcvChan chan *types.EventFields
@@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection // check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { if !l.dnsCollection.Load() && event.Protocol == types.UDP &&
(event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) {
return false return false
} }

View File

@@ -2,6 +2,8 @@ package relay
import ( import (
"context" "context"
"crypto/sha256"
"errors"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@@ -15,6 +17,15 @@ import (
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const (
DefaultCacheTTL = 20 * time.Second
probeTimeout = 6 * time.Second
)
var (
ErrCheckInProgress = errors.New("probe check is already in progress")
)
// ProbeResult holds the info about the result of a relay probe request // ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct { type ProbeResult struct {
URI string URI string
@@ -22,8 +33,164 @@ type ProbeResult struct {
Addr string Addr string
} }
type StunTurnProbe struct {
cacheResults []ProbeResult
cacheTimestamp time.Time
cacheKey string
cacheTTL time.Duration
probeInProgress bool
probeDone chan struct{}
mu sync.Mutex
}
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
return &StunTurnProbe{
cacheTTL: cacheTTL,
}
}
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if p.probeInProgress {
doneChan := p.probeDone
p.mu.Unlock()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-doneChan:
return p.getCachedResults(cacheKey, stuns, turns)
}
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
p.mu.Unlock()
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
return p.getCachedResults(cacheKey, stuns, turns)
}
// ProbeAll probes all given servers asynchronously and returns the results
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if results := p.checkCache(cacheKey); results != nil {
p.mu.Unlock()
return results
}
if p.probeInProgress {
p.mu.Unlock()
return createErrorResults(stuns, turns)
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
log.Infof("started new probe for STUN, TURN servers")
go func() {
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
}()
p.mu.Unlock()
timer := time.NewTimer(1300 * time.Millisecond)
defer timer.Stop()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-probeDone:
// when the probe is return fast, return the results right away
return p.getCachedResults(cacheKey, stuns, turns)
case <-timer.C:
// if the probe takes longer than 1.3s, return error results to avoid blocking
return createErrorResults(stuns, turns)
}
}
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
age := time.Since(p.cacheTimestamp)
if age < p.cacheTTL {
results := append([]ProbeResult(nil), p.cacheResults...)
log.Debugf("returning cached probe results (age: %v)", age)
return results
}
}
return nil
}
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
p.mu.Lock()
defer p.mu.Unlock()
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
return append([]ProbeResult(nil), p.cacheResults...)
}
return createErrorResults(stuns, turns)
}
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
defer func() {
p.mu.Lock()
p.probeInProgress = false
p.mu.Unlock()
}()
results := make([]ProbeResult, len(stuns)+len(turns))
var wg sync.WaitGroup
for i, uri := range stuns {
wg.Add(1)
go func(idx int, stunURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = stunURI.String()
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
}(i, uri)
}
stunOffset := len(stuns)
for i, uri := range turns {
wg.Add(1)
go func(idx int, turnURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = turnURI.String()
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
}(stunOffset+i, uri)
}
wg.Wait()
p.mu.Lock()
p.cacheResults = results
p.cacheTimestamp = time.Now()
p.cacheKey = cacheKey
p.mu.Unlock()
log.Debug("Stored new probe results in cache")
}
// ProbeSTUN tries binding to the given STUN uri and acquiring an address // ProbeSTUN tries binding to the given STUN uri and acquiring an address
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() { defer func() {
if probeErr != nil { if probeErr != nil {
log.Debugf("stun probe error from %s: %s", uri, probeErr) log.Debugf("stun probe error from %s: %s", uri, probeErr)
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
} }
// ProbeTURN tries allocating a session from the given TURN URI // ProbeTURN tries allocating a session from the given TURN URI
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() { defer func() {
if probeErr != nil { if probeErr != nil {
log.Debugf("turn probe error from %s: %s", uri, probeErr) log.Debugf("turn probe error from %s: %s", uri, probeErr)
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
return relayConn.LocalAddr().String(), nil return relayConn.LocalAddr().String(), nil
} }
// ProbeAll probes all given servers asynchronously and returns the results func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
func ProbeAll( total := len(stuns) + len(turns)
ctx context.Context, results := make([]ProbeResult, total)
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
relays []*stun.URI,
) []ProbeResult {
results := make([]ProbeResult, len(relays))
var wg sync.WaitGroup allURIs := append(append([]*stun.URI{}, stuns...), turns...)
for i, uri := range relays { for i, uri := range allURIs {
ctx, cancel := context.WithTimeout(ctx, 6*time.Second) results[i] = ProbeResult{
defer cancel() URI: uri.String(),
Err: ErrCheckInProgress,
wg.Add(1) }
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
} }
wg.Wait()
return results return results
} }
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
h := sha256.New()
for _, uri := range stuns {
h.Write([]byte(uri.String()))
}
for _, uri := range turns {
h.Write([]byte(uri.String()))
}
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -1,6 +1,7 @@
package common package common
import ( import (
"sync/atomic"
"time" "time"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
@@ -25,4 +26,5 @@ type HandlerParams struct {
UseNewDNSRoute bool UseNewDNSRoute bool
Firewall manager.Manager Firewall manager.Manager
FakeIPManager *fakeip.Manager FakeIPManager *fakeip.Manager
ForwarderPort *atomic.Uint32
} }

View File

@@ -8,6 +8,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@@ -18,7 +19,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/common"
@@ -55,6 +55,7 @@ type DnsInterceptor struct {
peerStore *peerstore.Store peerStore *peerstore.Store
firewall firewall.Manager firewall firewall.Manager
fakeIPManager *fakeip.Manager fakeIPManager *fakeip.Manager
forwarderPort *atomic.Uint32
} }
func New(params common.HandlerParams) *DnsInterceptor { func New(params common.HandlerParams) *DnsInterceptor {
@@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor {
firewall: params.Firewall, firewall: params.Firewall,
fakeIPManager: params.FakeIPManager, fakeIPManager: params.FakeIPManager,
interceptedDomains: make(domainMap), interceptedDomains: make(domainMap),
forwarderPort: params.ForwarderPort,
} }
} }
@@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel() defer cancel()

View File

@@ -10,6 +10,7 @@ import (
"runtime" "runtime"
"slices" "slices"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -23,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client" "github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -54,6 +56,7 @@ type Manager interface {
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
SetFirewall(firewall.Manager) error SetFirewall(firewall.Manager) error
SetDNSForwarderPort(port uint16)
Stop(stateManager *statemanager.Manager) Stop(stateManager *statemanager.Manager)
} }
@@ -101,12 +104,13 @@ type DefaultManager struct {
disableServerRoutes bool disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.Manager fakeIPManager *fakeip.Manager
dnsForwarderPort atomic.Uint32
} }
func NewManager(config ManagerConfig) *DefaultManager { func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context) mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier) sysOps := systemops.New(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil { if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name()) nbnet.SetVPNInterfaceName(config.WGInterface.Name())
@@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
} }
dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort))
useNoop := netstack.IsEnabled() || config.DisableClientRoutes useNoop := netstack.IsEnabled() || config.DisableClientRoutes
dm.setupRefCounters(useNoop) dm.setupRefCounters(useNoop)
@@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
return nil return nil
} }
// SetDNSForwarderPort sets the DNS forwarder port for route handlers
func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
m.dnsForwarderPort.Store(uint32(port))
}
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop() m.stop()
@@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
UseNewDNSRoute: m.useNewDNSRoute, UseNewDNSRoute: m.useNewDNSRoute,
Firewall: m.firewall, Firewall: m.firewall,
FakeIPManager: m.fakeIPManager, FakeIPManager: m.fakeIPManager,
ForwarderPort: &m.dnsForwarderPort,
} }
handler := client.HandlerFromRoute(params) handler := client.HandlerFromRoute(params)
if err := handler.AddRoute(m.ctx); err != nil { if err := handler.AddRoute(m.ctx); err != nil {

View File

@@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error {
panic("implement me") panic("implement me")
} }
// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface
func (m *MockManager) SetDNSForwarderPort(port uint16) {
}
// Stop mock implementation of Stop from Manager interface // Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop(stateManager *statemanager.Manager) { func (m *MockManager) Stop(stateManager *statemanager.Manager) {
if m.StopFunc != nil { if m.StopFunc != nil {

View File

@@ -0,0 +1,8 @@
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
package systemops
// FlushMarkedRoutes is a no-op on non-BSD platforms.
func (r *SysOps) FlushMarkedRoutes() error {
return nil
}

View File

@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil) sysOps := New(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s)) sysOps.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush() return sysOps.refCounter.Flush()
} }
func (s *ShutdownState) MarshalJSON() ([]byte, error) { func (s *ShutdownState) MarshalJSON() ([]byte, error) {

View File

@@ -83,7 +83,7 @@ type SysOps struct {
localSubnetsCacheTime time.Time localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier, notifier: notifier,

View File

@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
_, intf = setupDummyInterface(t) _, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf} nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil) r := New(nil, nil)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
nexthop := Nexthop{netip.Addr{}, netIntf} nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil) r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop) err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table") require.NoError(t, err, "Failed to add route to table")

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close()) assert.NoError(t, wgInterface.Close())
}) })
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")

View File

@@ -7,19 +7,39 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"strconv" "strconv"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/route" "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
)
var routeProtoFlag int
func init() {
switch os.Getenv(envRouteProtoFlag) {
case "2":
routeProtoFlag = unix.RTF_PROTO2
case "3":
routeProtoFlag = unix.RTF_PROTO3
default:
routeProtoFlag = unix.RTF_PROTO1
}
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
func (r *SysOps) FlushMarkedRoutes() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("Skipping route flush: %v", err)
continue
}
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
continue
}
nexthop := Nexthop{
IP: routeInfo.Gw,
Intf: routeInfo.Interface,
}
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
continue
}
flushedCount++
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
}
if flushedCount > 0 {
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_ADD, prefix, nexthop) return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
} }
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{ msg = &route.RouteMessage{
Type: action, Type: action,
Flags: unix.RTF_UP, Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Seq: r.getSeq(), Seq: r.getSeq(),
} }

View File

@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
data, err := os.ReadFile(m.filePath) data, err := os.ReadFile(m.filePath)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist") log.Debugf("state file %s does not exist", m.filePath)
return nil, nil // nolint:nilnil return nil, nil // nolint:nilnil
} }
return nil, fmt.Errorf("read state file: %w", err) return nil, fmt.Errorf("read state file: %w", err)

View File

@@ -0,0 +1,59 @@
package winregistry
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows/registry"
)
var (
advapi = syscall.NewLazyDLL("advapi32.dll")
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
)
const (
// Registry key options
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
// Registry disposition values
regCreatedNewKey = 0x1
regOpenedExistingKey = 0x2
)
// CreateVolatileKey creates a volatile registry key named path under open key root.
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
// The access parameter specifies the access rights for the key to be created.
//
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
// This provides automatic cleanup without requiring manual registry maintenance.
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
pathPtr, err := syscall.UTF16PtrFromString(path)
if err != nil {
return 0, false, err
}
var (
handle syscall.Handle
disposition uint32
)
ret, _, _ := regCreateKeyExW.Call(
uintptr(root),
uintptr(unsafe.Pointer(pathPtr)),
0, // reserved
0, // class
uintptr(regOptionVolatile), // options - volatile key
uintptr(access), // desired access
0, // security attributes
uintptr(unsafe.Pointer(&handle)),
uintptr(unsafe.Pointer(&disposition)),
)
if ret != 0 {
return 0, false, syscall.Errno(ret)
}
return registry.Key(handle), disposition == regOpenedExistingKey, nil
}

View File

@@ -17,8 +17,7 @@ type Conn struct {
ID hooks.ConnectionID ID hooks.ConnectionID
} }
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection // Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
func (c *Conn) Close() error { func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn) return closeConn(c.ID, c.Conn)
} }
@@ -29,7 +28,7 @@ type TCPConn struct {
ID hooks.ConnectionID ID hooks.ConnectionID
} }
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. // Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
func (c *TCPConn) Close() error { func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn) return closeConn(c.ID, c.TCPConn)
} }
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
// closeConn is a helper function to close connections and execute close hooks. // closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error { func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close() err := conn.Close()
cleanupConnID(id)
return err
}
// cleanupConnID executes close hooks for a connection ID.
func cleanupConnID(id hooks.ConnectionID) {
closeHooks := hooks.GetCloseHooks() closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks { for _, hook := range closeHooks {
if err := hook(id); err != nil { if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err) log.Errorf("Error executing close hook: %v", err)
} }
} }
return err
} }

View File

@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
} }
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
} }
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err) log.Errorf("failed to close connection: %v", err)
} }

View File

@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
cleanupConnID(connID)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
} }
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
ips, err := resolver.LookupIPAddr(ctx, host) ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil { if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err) return fmt.Errorf("resolve address %s: %w", address, err)
} }
log.Debugf("Dialer resolved IPs for %s: %v", address, ips) log.Debugf("Dialer resolved IPs for %s: %v", address, ips)

View File

@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.PacketConn.WriteTo(b, addr) return c.PacketConn.WriteTo(b, addr)
} }
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. // Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
func (c *PacketConn) Close() error { func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear() defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn) return closeConn(c.ID, c.PacketConn)
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.UDPConn.WriteTo(b, addr) return c.UDPConn.WriteTo(b, addr)
} }
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. // Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
func (c *UDPConn) Close() error { func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear() defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn) return closeConn(c.ID, c.UDPConn)

View File

@@ -1057,10 +1057,7 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus { if msg.GetFullPeerStatus {
if msg.ShouldRunProbes { s.runProbes(msg.ShouldRunProbes)
s.runProbes()
}
fullStatus := s.statusRecorder.GetFullStatus() fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory() pbFullStatus.Events = s.statusRecorder.GetEventHistory()
@@ -1070,7 +1067,7 @@ func (s *Server) Status(
return &statusResponse, nil return &statusResponse, nil
} }
func (s *Server) runProbes() { func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil { if s.connectClient == nil {
return return
} }
@@ -1081,7 +1078,7 @@ func (s *Server) runProbes() {
} }
if time.Since(s.lastProbe) > probeThreshold { if time.Since(s.lastProbe) > probeThreshold {
if engine.RunHealthProbes() { if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now() s.lastProbe = time.Now()
} }
} }

View File

@@ -10,7 +10,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
} }
// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
for _, relay := range overview.Relays.Details { for _, relay := range overview.Relays.Details {
available := "Available" available := "Available"
reason := "" reason := ""
if !relay.Available { if !relay.Available {
available = "Unavailable" if relay.Error == probeRelay.ErrCheckInProgress.Error() {
reason = fmt.Sprintf(", reason: %s", relay.Error) available = "Checking..."
} else {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} }
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
} }
} else { } else {

View File

@@ -296,6 +296,8 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem
logFile string logFile string
wLoginURL fyne.Window wLoginURL fyne.Window
connectCancel context.CancelFunc
} }
type menuHandler struct { type menuHandler struct {
@@ -592,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
} }
} }
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) return nil, fmt.Errorf("get daemon client: %w", err)
return nil, err
} }
activeProf, err := s.profileManager.GetActiveProfile() activeProf, err := s.profileManager.GetActiveProfile()
if err != nil { if err != nil {
log.Errorf("get active profile: %v", err) return nil, fmt.Errorf("get active profile: %w", err)
return nil, err
} }
currUser, err := user.Current() currUser, err := user.Current()
@@ -610,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
return nil, fmt.Errorf("get current user: %w", err) return nil, fmt.Errorf("get current user: %w", err)
} }
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ loginResp, err := conn.Login(ctx, &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name, ProfileName: &activeProf.Name,
Username: &currUser.Username, Username: &currUser.Username,
}) })
if err != nil { if err != nil {
log.Errorf("login to management URL with: %v", err) return nil, fmt.Errorf("login to management: %w", err)
return nil, err
} }
if loginResp.NeedsSSOLogin && openURL { if loginResp.NeedsSSOLogin && openURL {
err = s.handleSSOLogin(loginResp, conn) if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
if err != nil { return nil, fmt.Errorf("SSO login: %w", err)
log.Errorf("handle SSO login failed: %v", err)
return nil, err
} }
} }
return loginResp, nil return loginResp, nil
} }
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
err := openURL(loginResp.VerificationURIComplete) if err := openURL(loginResp.VerificationURIComplete); err != nil {
if err != nil { return fmt.Errorf("open browser: %w", err)
log.Errorf("opening the verification uri in the browser failed: %v", err)
return err
} }
resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil { if err != nil {
log.Errorf("waiting sso login failed with: %v", err) return fmt.Errorf("wait for SSO login: %w", err)
return err
} }
if resp.Email != "" { if resp.Email != "" {
err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
Email: resp.Email, Email: resp.Email,
}) }); err != nil {
if err != nil { log.Debugf("failed to set profile state: %v", err)
log.Warnf("failed to set profile state: %v", err)
} else { } else {
s.mProfile.refresh() s.mProfile.refresh()
} }
} }
return nil return nil
} }
func (s *serviceClient) menuUpClick() error { func (s *serviceClient) menuUpClick(ctx context.Context) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
systray.SetTemplateIcon(iconErrorMacOS, s.icError) systray.SetTemplateIcon(iconErrorMacOS, s.icError)
log.Errorf("get client: %v", err) return fmt.Errorf("get daemon client: %w", err)
return err
} }
_, err = s.login(true) _, err = s.login(ctx, true)
if err != nil { if err != nil {
log.Errorf("login failed with: %v", err) return fmt.Errorf("login: %w", err)
return err
} }
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) return fmt.Errorf("get status: %w", err)
return err
} }
if status.Status == string(internal.StatusConnected) { if status.Status == string(internal.StatusConnected) {
log.Warnf("already connected")
return nil return nil
} }
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
log.Errorf("up service: %v", err) return fmt.Errorf("start connection: %w", err)
return err
} }
return nil return nil
@@ -697,24 +684,20 @@ func (s *serviceClient) menuDownClick() error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) return fmt.Errorf("get daemon client: %w", err)
return err
} }
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) return fmt.Errorf("get status: %w", err)
return err
} }
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
log.Warnf("already down")
return nil return nil
} }
if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("down service: %v", err) return fmt.Errorf("stop connection: %w", err)
return err
} }
return nil return nil
@@ -850,6 +833,7 @@ func (s *serviceClient) onTrayReady() {
newProfileMenuArgs := &newProfileMenuArgs{ newProfileMenuArgs := &newProfileMenuArgs{
ctx: s.ctx, ctx: s.ctx,
serviceClient: s,
profileManager: s.profileManager, profileManager: s.profileManager,
eventHandler: s.eventHandler, eventHandler: s.eventHandler,
profileMenuItem: profileMenuItem, profileMenuItem: profileMenuItem,
@@ -1381,7 +1365,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return return
} }
resp, err := s.login(false) resp, err := s.login(ctx, false)
if err != nil { if err != nil {
log.Errorf("failed to fetch login URL: %v", err) log.Errorf("failed to fetch login URL: %v", err)
return return
@@ -1401,7 +1385,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return return
} }
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
if err != nil { if err != nil {
log.Errorf("Waiting sso login failed with: %v", err) log.Errorf("Waiting sso login failed with: %v", err)
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
@@ -1409,7 +1393,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
} }
label.SetText("Re-authentication successful.\nReconnecting") label.SetText("Re-authentication successful.\nReconnecting")
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) log.Errorf("get service status: %v", err)
return return
@@ -1422,7 +1406,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return return
} }
_, err = conn.Up(s.ctx, &proto.UpRequest{}) _, err = conn.Up(ctx, &proto.UpRequest{})
if err != nil { if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
log.Errorf("Reconnecting failed with: %v", err) log.Errorf("Reconnecting failed with: %v", err)

View File

@@ -18,6 +18,7 @@ import (
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types" uptypes "github.com/netbirdio/netbird/upload-server/types"
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
return "", err return "", err
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
log.Warnf("Failed to get post-up status: %v", err) log.Warnf("Failed to get post-up status: %v", err)
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err) return nil, fmt.Errorf("get client: %v", err)
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err) log.Warnf("failed to get status for debug bundle: %v", err)
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }

View File

@@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/systray" "fyne.io/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) {
func (h *eventHandler) handleConnectClick() { func (h *eventHandler) handleConnectClick() {
h.client.mUp.Disable() h.client.mUp.Disable()
if h.client.connectCancel != nil {
h.client.connectCancel()
}
connectCtx, connectCancel := context.WithCancel(h.client.ctx)
h.client.connectCancel = connectCancel
go func() { go func() {
defer h.client.mUp.Enable() defer connectCancel()
if err := h.client.menuUpClick(); err != nil {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) if err := h.client.menuUpClick(connectCtx); err != nil {
st, ok := status.FromError(err)
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user")
} else {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
log.Errorf("connect failed: %v", err)
}
}
if err := h.client.updateStatus(); err != nil {
log.Debugf("failed to update status after connect: %v", err)
} }
}() }()
} }
func (h *eventHandler) handleDisconnectClick() { func (h *eventHandler) handleDisconnectClick() {
h.client.mDown.Disable() h.client.mDown.Disable()
if h.client.connectCancel != nil {
log.Debugf("cancelling ongoing connect operation")
h.client.connectCancel()
h.client.connectCancel = nil
}
go func() { go func() {
defer h.client.mDown.Enable()
if err := h.client.menuDownClick(); err != nil { if err := h.client.menuDownClick(); err != nil {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) st, ok := status.FromError(err)
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
log.Errorf("disconnect failed: %v", err)
} else {
log.Debugf("disconnect cancelled or already disconnecting")
}
}
if err := h.client.updateStatus(); err != nil {
log.Debugf("failed to update status after disconnect: %v", err)
} }
}() }()
} }
@@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
} }
h.client.getSrvConfig() h.client.getSrvConfig()
return nil return nil
} }

View File

@@ -387,6 +387,7 @@ type subItem struct {
type profileMenu struct { type profileMenu struct {
mu sync.Mutex mu sync.Mutex
ctx context.Context ctx context.Context
serviceClient *serviceClient
profileManager *profilemanager.ProfileManager profileManager *profilemanager.ProfileManager
eventHandler *eventHandler eventHandler *eventHandler
profileMenuItem *systray.MenuItem profileMenuItem *systray.MenuItem
@@ -396,7 +397,7 @@ type profileMenu struct {
logoutSubItem *subItem logoutSubItem *subItem
profilesState []Profile profilesState []Profile
downClickCallback func() error downClickCallback func() error
upClickCallback func() error upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func() loadSettingsCallback func()
app fyne.App app fyne.App
@@ -404,12 +405,13 @@ type profileMenu struct {
type newProfileMenuArgs struct { type newProfileMenuArgs struct {
ctx context.Context ctx context.Context
serviceClient *serviceClient
profileManager *profilemanager.ProfileManager profileManager *profilemanager.ProfileManager
eventHandler *eventHandler eventHandler *eventHandler
profileMenuItem *systray.MenuItem profileMenuItem *systray.MenuItem
emailMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem
downClickCallback func() error downClickCallback func() error
upClickCallback func() error upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func() loadSettingsCallback func()
app fyne.App app fyne.App
@@ -418,6 +420,7 @@ type newProfileMenuArgs struct {
func newProfileMenu(args newProfileMenuArgs) *profileMenu { func newProfileMenu(args newProfileMenuArgs) *profileMenu {
p := profileMenu{ p := profileMenu{
ctx: args.ctx, ctx: args.ctx,
serviceClient: args.serviceClient,
profileManager: args.profileManager, profileManager: args.profileManager,
eventHandler: args.eventHandler, eventHandler: args.eventHandler,
profileMenuItem: args.profileMenuItem, profileMenuItem: args.profileMenuItem,
@@ -569,10 +572,19 @@ func (p *profileMenu) refresh() {
} }
} }
if err := p.upClickCallback(); err != nil { if p.serviceClient.connectCancel != nil {
p.serviceClient.connectCancel()
}
connectCtx, connectCancel := context.WithCancel(p.ctx)
p.serviceClient.connectCancel = connectCancel
if err := p.upClickCallback(connectCtx); err != nil {
log.Errorf("failed to handle up click after switching profile: %v", err) log.Errorf("failed to handle up click after switching profile: %v", err)
} }
connectCancel()
p.refresh() p.refresh()
p.loadSettingsCallback() p.loadSettingsCallback()
} }

View File

@@ -19,6 +19,10 @@ const (
RootZone = "." RootZone = "."
// DefaultClass is the class supported by the system // DefaultClass is the class supported by the system
DefaultClass = "IN" DefaultClass = "IN"
// ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort.
ForwarderClientPort uint16 = 5353
// ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here.
ForwarderServerPort uint16 = 22054
) )
const invalidHostLabel = "[^a-zA-Z0-9-]+" const invalidHostLabel = "[^a-zA-Z0-9-]+"
@@ -31,6 +35,8 @@ type Config struct {
NameServerGroups []*NameServerGroup NameServerGroups []*NameServerGroup
// CustomZones contains a list of custom zone // CustomZones contains a list of custom zone
CustomZones []CustomZone CustomZones []CustomZone
// ForwarderPort is the port clients should connect to on routing peers for DNS forwarding
ForwarderPort uint16
} }
// CustomZone represents a custom zone to be resolved by the dns server // CustomZone represents a custom zone to be resolved by the dns server

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
echo "" echo ""
export NETBIRD_SIGNAL_PROTOCOL="https"
unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_LETSENCRYPT_DOMAIN
unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_FILE
unset NETBIRD_MGMT_API_CERT_KEY_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE
fi fi
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
export NETBIRD_SIGNAL_PROTOCOL="https"
fi
# Check if management identity provider is set # Check if management identity provider is set
if [ -n "$NETBIRD_MGMT_IDP" ]; then if [ -n "$NETBIRD_MGMT_IDP" ]; then
EXTRA_CONFIG={} EXTRA_CONFIG={}

View File

@@ -40,13 +40,21 @@ services:
signal: signal:
<<: *default <<: *default
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
depends_on:
- dashboard
volumes: volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird - $SIGNAL_VOLUMENAME:/var/lib/netbird
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
ports: ports:
- $NETBIRD_SIGNAL_PORT:80 - $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation # # port and command for Let's Encrypt validation
# - 443:443 # - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
command: [
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
"--log-file", "console"
]
# Relay # Relay
relay: relay:

View File

@@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
func (s *BaseServer) PermissionsManager() permissions.Manager { func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager { return Create(s, func() permissions.Manager {
return integrations.InitPermissionsManager(s.Store()) manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
s.AfterInit(func(s *BaseServer) {
manager.SetAccountManager(s.AccountManager())
})
return manager
}) })
} }

View File

@@ -109,7 +109,7 @@ type Manager interface {
GetIdpManager() idp.Manager GetIdpManager() idp.Manager
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error

View File

@@ -21,8 +21,8 @@ import (
) )
const ( const (
dnsForwarderPort = 22054 dnsForwarderPort = nbdns.ForwarderServerPort
oldForwarderPort = 5353 oldForwarderPort = nbdns.ForwarderClientPort
) )
const dnsForwarderPortMinVersion = "v0.59.0" const dnsForwarderPortMinVersion = "v0.59.0"
@@ -196,7 +196,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 { if len(peers) == 0 {
return oldForwarderPort return int64(oldForwarderPort)
} }
reqVer := semver.Canonical(requiredVersion) reqVer := semver.Canonical(requiredVersion)
@@ -211,17 +211,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" { if peerVersion == "" {
// If any peer doesn't have version info, return 0 // If any peer doesn't have version info, return 0
return oldForwarderPort return int64(oldForwarderPort)
} }
// Compare versions // Compare versions
if semver.Compare(peerVersion, reqVer) < 0 { if semver.Compare(peerVersion, reqVer) < 0 {
return oldForwarderPort return int64(oldForwarderPort)
} }
} }
// All peers have the required version or newer // All peers have the required version or newer
return dnsForwarderPort return int64(dnsForwarderPort)
} }
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache

View File

@@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, dnsForwarderPort) toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
} }
}) })
@@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{} cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache, dnsForwarderPort) toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
} }
}) })
} }
@@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
} }
// First run with config1 // First run with config1
result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Second run with config2 // Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
// Third run with config1 again // Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Verify that result1 and result3 are identical // Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) { if !reflect.DeepEqual(result1, result3) {
@@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) {
// Test with empty peers list // Test with empty peers list
peers := []*nbpeer.Peer{} peers := []*nbpeer.Peer{}
result := computeForwarderPort(peers, "v0.59.0") result := computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort { if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
} }
@@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) {
}, },
} }
result = computeForwarderPort(peers, "v0.59.0") result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort { if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
} }
@@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) {
}, },
} }
result = computeForwarderPort(peers, "v0.59.0") result = computeForwarderPort(peers, "v0.59.0")
if result != dnsForwarderPort { if result != int64(dnsForwarderPort) {
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
} }
@@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) {
}, },
} }
result = computeForwarderPort(peers, "v0.59.0") result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort { if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
} }
@@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) {
}, },
} }
result = computeForwarderPort(peers, "v0.59.0") result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort { if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
} }
@@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) {
}, },
} }
result = computeForwarderPort(peers, "v0.59.0") result = computeForwarderPort(peers, "v0.59.0")
if result == oldForwarderPort { if result == int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
} }
@@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) {
}, },
} }
result = computeForwarderPort(peers, "v0.59.0") result = computeForwarderPort(peers, "v0.59.0")
if result != oldForwarderPort { if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
} }
} }

View File

@@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
@@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
} }
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
} }
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
} }
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
} }
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return return
} }
h.setApprovalRequiredFlag(respBody, validPeersMap) h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
util.WriteJSONObject(r.Context(), w, respBody) util.WriteJSONObject(r.Context(), w, respBody)
} }
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
for _, peer := range respBody { for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id] _, ok := validPeersMap[peer.Id]
if !ok { if !ok {
peer.ApprovalRequired = true peer.ApprovalRequired = true
reason := invalidPeersMap[peer.Id]
peer.DisapprovalReason = &reason
} }
} }
} }
@@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
} }
} }
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
} }
} }
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
osVersion := peer.Meta.OSVersion osVersion := peer.Meta.OSVersion
if osVersion == "" { if osVersion == "" {
osVersion = peer.Meta.Core osVersion = peer.Meta.Core
} }
return &api.Peer{ apiPeer := &api.Peer{
CreatedAt: peer.CreatedAt, CreatedAt: peer.CreatedAt,
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
@@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
InactivityExpirationEnabled: peer.InactivityExpirationEnabled, InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral, Ephemeral: peer.Ephemeral,
} }
if !approved {
apiPeer.DisapprovalReason = &reason
}
return apiPeer
} }
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {

View File

@@ -7,9 +7,10 @@ import (
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"

View File

@@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
return true, nil return true, nil
} }
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
var err error var err error
var groups []*types.Group var groups []*types.Group
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
@@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
if err != nil {
return nil, nil, err
}
invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra)
if err != nil {
return nil, nil, err
}
return validPeers, invalidPeers, nil
} }
type MockIntegratedValidator struct { type MockIntegratedValidator struct {
@@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil return validatedPeers, nil
} }
func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) {
return make(map[string]string), nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer return peer
} }

View File

@@ -15,6 +15,7 @@ type IntegratedValidator interface {
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
Stop(ctx context.Context) Stop(ctx context.Context)

View File

@@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me") panic("implement me")
} }
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
account, err := am.GetAccountFunc(ctx, accountID) account, err := am.GetAccountFunc(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
approvedPeers := make(map[string]struct{}) approvedPeers := make(map[string]struct{})
for id := range account.Peers { for id := range account.Peers {
approvedPeers[id] = struct{}{} approvedPeers[id] = struct{}{}
} }
return approvedPeers, nil return approvedPeers, nil, nil
} }
// GetGroup mock implementation of GetGroup from server.AccountManager interface // GetGroup mock implementation of GetGroup from server.AccountManager interface

View File

@@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) {
} }
dnsCache := &DNSConfigCache{} dnsCache := &DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
assert.NotNil(t, response) assert.NotNil(t, response)
// assert peer config // assert peer config

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -22,6 +23,7 @@ type Manager interface {
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
SetAccountManager(accountManager account.Manager)
} }
type managerImpl struct { type managerImpl struct {
@@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
return permissions, nil return permissions, nil
} }
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
// no-op
}

View File

@@ -9,6 +9,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
account "github.com/netbirdio/netbird/management/server/account"
modules "github.com/netbirdio/netbird/management/server/permissions/modules" modules "github.com/netbirdio/netbird/management/server/permissions/modules"
operations "github.com/netbirdio/netbird/management/server/permissions/operations" operations "github.com/netbirdio/netbird/management/server/permissions/operations"
roles "github.com/netbirdio/netbird/management/server/permissions/roles" roles "github.com/netbirdio/netbird/management/server/permissions/roles"
@@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
} }
// SetAccountManager mocks base method.
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAccountManager", accountManager)
}
// SetAccountManager indicates an expected call of SetAccountManager.
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
}
// ValidateAccountAccess mocks base method. // ValidateAccountAccess mocks base method.
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus { if dnsManagementStatus {
var zones []nbdns.CustomZone var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" { if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{ zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain, Domain: peersCustomZone.Domain,
Records: records, Records: records,
@@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool {
} }
// filterZoneRecordsForPeers filters DNS records to only include peers to connect. // filterZoneRecordsForPeers filters DNS records to only include peers to connect.
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
peerIPs := make(map[string]struct{}) peerIPs := make(map[string]struct{})
@@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
peerIPs[peerToConnect.IP.String()] = struct{}{} peerIPs[peerToConnect.IP.String()] = struct{}{}
} }
for _, expiredPeer := range expiredPeers {
peerIPs[expiredPeer.IP.String()] = struct{}{}
}
for _, record := range customZone.Records { for _, record := range customZone.Records {
if _, exists := peerIPs[record.RData]; exists { if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record) filteredRecords = append(filteredRecords, record)

View File

@@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
peer *nbpeer.Peer peer *nbpeer.Peer
customZone nbdns.CustomZone customZone nbdns.CustomZone
peersToConnect []*nbpeer.Peer peersToConnect []*nbpeer.Peer
expiredPeers []*nbpeer.Peer
expectedRecords []nbdns.SimpleRecord expectedRecords []nbdns.SimpleRecord
}{ }{
{ {
@@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
}, },
}, },
peersToConnect: []*nbpeer.Peer{}, peersToConnect: []*nbpeer.Peer{},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{ expectedRecords: []nbdns.SimpleRecord{
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
@@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
} }
return peers return peers
}(), }(),
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: func() []nbdns.SimpleRecord { expectedRecords: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord var records []nbdns.SimpleRecord
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
@@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
}, },
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{ expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
@@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
}, },
}, },
{
name: "expired peers are included in DNS entries",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{
{ID: "peer1", IP: net.ParseIP("10.0.0.1")},
},
expiredPeers: []*nbpeer.Peer{
{ID: "expired-peer", IP: net.ParseIP("10.0.0.99")},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers)
assert.Equal(t, len(tt.expectedRecords), len(result)) assert.Equal(t, len(tt.expectedRecords), len(result))
assert.ElementsMatch(t, tt.expectedRecords, result) assert.ElementsMatch(t, tt.expectedRecords, result)
}) })

View File

@@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then
NETBIRD_RELEASE=latest NETBIRD_RELEASE=latest
fi fi
TAG_NAME=""
get_release() { get_release() {
local RELEASE=$1 local RELEASE=$1
if [ "$RELEASE" = "latest" ]; then if [ "$RELEASE" = "latest" ]; then
@@ -38,17 +40,19 @@ get_release() {
local TAG="tags/${RELEASE}" local TAG="tags/${RELEASE}"
local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
fi fi
OUTPUT=""
if [ -n "$GITHUB_TOKEN" ]; then if [ -n "$GITHUB_TOKEN" ]; then
curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}")
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
else else
curl -s "${URL}" \ OUTPUT=$(curl -s "${URL}")
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
fi fi
TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1)
echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+'
} }
download_release_binary() { download_release_binary() {
VERSION=$(get_release "$NETBIRD_RELEASE") VERSION=$(get_release "$NETBIRD_RELEASE")
echo "Using the following tag name for binary installation: ${TAG_NAME}"
BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download"
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz"

View File

@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
var err error var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil { if err != nil {
log.Printf("createConnection error: %v", err) return fmt.Errorf("create connection: %w", err)
return err
} }
return nil return nil
} }

View File

@@ -463,6 +463,9 @@ components:
description: (Cloud only) Indicates whether peer needs approval description: (Cloud only) Indicates whether peer needs approval
type: boolean type: boolean
example: true example: true
disapproval_reason:
description: (Cloud only) Reason why the peer requires approval
type: string
country_code: country_code:
$ref: '#/components/schemas/CountryCode' $ref: '#/components/schemas/CountryCode'
city_name: city_name:

View File

@@ -1037,6 +1037,9 @@ type Peer struct {
// CreatedAt Peer creation date (UTC) // CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
// DisapprovalReason (Cloud only) Reason why the peer requires approval
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
@@ -1124,6 +1127,9 @@ type PeerBatch struct {
// CreatedAt Peer creation date (UTC) // CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
// DisapprovalReason (Cloud only) Reason why the peer requires approval
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`

View File

@@ -410,7 +410,7 @@ message DNSConfig {
bool ServiceEnable = 1; bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2; repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3; repeated CustomZone CustomZones = 3;
int64 ForwarderPort = 4; int64 ForwarderPort = 4 [deprecated = true];
} }
// CustomZone represents a dns.CustomZone // CustomZone represents a dns.CustomZone

View File

@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
var err error var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil { if err != nil {
log.Printf("createConnection error: %v", err) return fmt.Errorf("create connection: %w", err)
return err
} }
return nil return nil
} }

View File

@@ -94,7 +94,7 @@ var (
startPprof() startPprof()
opts, certManager, err := getTLSConfigurations() opts, certManager, tlsConfig, err := getTLSConfigurations()
if err != nil { if err != nil {
return err return err
} }
@@ -132,7 +132,7 @@ var (
// Start the main server - always serve HTTP with WebSocket proxy support // Start the main server - always serve HTTP with WebSocket proxy support
// If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager
if certManager == nil { if tlsConfig == nil {
// Without TLS, serve plain HTTP // Without TLS, serve plain HTTP
httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort))
if err != nil { if err != nil {
@@ -140,9 +140,10 @@ var (
} }
log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String())
serveHTTP(httpListener, grpcRootHandler) serveHTTP(httpListener, grpcRootHandler)
} else if signalPort != 443 { } else if certManager == nil || signalPort != 443 {
// With TLS but not on port 443, serve HTTPS // Serve HTTPS if not already handled by startServerWithCertManager
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) // (custom certificates or Let's Encrypt with custom port)
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig)
if err != nil { if err != nil {
return err return err
} }
@@ -202,7 +203,7 @@ func startPprof() {
}() }()
} }
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) {
var ( var (
err error err error
certManager *autocert.Manager certManager *autocert.Manager
@@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
log.Infof("running without TLS") log.Infof("running without TLS")
return nil, nil, nil return nil, nil, nil, nil
} }
if signalLetsencryptDomain != "" { if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil { if err != nil {
return nil, certManager, err return nil, certManager, nil, err
} }
tlsConfig = certManager.TLSConfig() tlsConfig = certManager.TLSConfig()
log.Infof("setting up TLS with LetsEncrypt.") log.Infof("setting up TLS with LetsEncrypt.")
} else { } else {
if signalCertFile == "" || signalCertKey == "" { if signalCertFile == "" || signalCertKey == "" {
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
} }
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
if err != nil { if err != nil {
log.Errorf("cannot load TLS credentials: %v", err) log.Errorf("cannot load TLS credentials: %v", err)
return nil, certManager, err return nil, certManager, nil, err
} }
log.Infof("setting up TLS with custom certificates.") log.Infof("setting up TLS with custom certificates.")
} }
transportCredentials := credentials.NewTLS(tlsConfig) transportCredentials := credentials.NewTLS(tlsConfig)
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err
} }
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {