Compare commits

...

37 Commits

Author SHA1 Message Date
Viktor Liu
1ee575befe [client] Use management-provided dns forwarder port on the client side (#4712) 2025-10-28 22:58:43 +01:00
Viktor Liu
d3a34adcc9 [client] Fix Connect/Disconnect buttons being enabled or disabled at the same time (#4711) 2025-10-28 21:21:40 +01:00
Zoltan Papp
d7321c130b [client] The status cmd will not be blocked by the ICE probe (#4597)
The status cmd will not be blocked by the ICE probe

Refactor the TURN and STUN probe, and cache the results. The NetBird status command will indicate a "checking…" state.
2025-10-28 16:11:35 +01:00
Viktor Liu
404cab90ba [client] Redirect dns forwarder port 5353 to new listening port 22054 (#4707)
- Port dnat changes from https://github.com/netbirdio/netbird/pull/4015 (nftables/iptables/userspace)
  - For userspace: rewrite the original port to the target port
  - Remember original destination port in conntrack
  - Rewrite the source port back to the original port for replies
- Redirect incoming port 5353 to 22054 (tcp/udp)
- Revert port changes based on the network map received from management
- Adjust tracer to show NAT stages
2025-10-28 15:12:53 +01:00
Pascal Fischer
4545ab9a52 [management] rewire account manager to permissions manager (#4673) 2025-10-27 22:59:35 +01:00
Bethuel Mmbaga
7f08983207 Include expired and routing peers in DNS record filtering (#4708) 2025-10-27 22:16:17 +03:00
Viktor Liu
eddea14521 [client] Clean up bsd routes independently of the state file (#4688) 2025-10-27 18:54:00 +01:00
Viktor Liu
b9ef214ea5 [client] Fix macOS state-based dns cleanup (#4701) 2025-10-27 18:35:32 +01:00
Bethuel Mmbaga
709e24eb6f [signal] Fix HTTP/WebSocket proxy not using custom certificates (#4644)
This pull request fixes a bug where the HTTP/WebSocket proxy server was not using custom TLS certificates when provided via --cert-file and --cert-key flags. Previously, only the gRPC server had TLS enabled with custom certificates, while the HTTP/WebSocket proxy ran without TLS.
2025-10-24 15:40:20 +03:00
Viktor Liu
6654e2dbf7 [client] Fix active profile name in debug bundle (#4689) 2025-10-23 17:07:52 +02:00
Bethuel Mmbaga
d80d47a469 [management] Add peer disapproval reason (#4468) 2025-10-22 12:46:22 +03:00
Maycon Santos
96f71ff1e1 [misc] Update tag name extraction in install.sh (#4677) 2025-10-21 19:23:11 +02:00
Viktor Liu
2fe2af38d2 [client] Clean up match domain reg entries between config changes (#4676) 2025-10-21 18:14:39 +02:00
Misha Bragin
cd9a867ad0 [client] Delete TURNConfig section from script (#4639) 2025-10-17 19:48:26 +02:00
Maycon Santos
0f9bfeff7c [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 2025-10-17 19:47:11 +02:00
Viktor Liu
f5301230bf [client] Fix status showing P2P without connection (#4661) 2025-10-17 13:31:15 +02:00
Viktor Liu
429d7d6585 [client] Support BROWSER env for login (#4654) 2025-10-17 11:10:16 +02:00
Viktor Liu
3cdb10cde7 [client] Remove rule squashing (#4653) 2025-10-17 11:09:39 +02:00
Zoltan Papp
af95aabb03 Handle the case when the service has already been down and the status recorder is not available (#4652) 2025-10-16 17:15:39 +02:00
Viktor Liu
3abae0bd17 [client] Set default wg port for new profiles (#4651) 2025-10-16 16:16:51 +02:00
Viktor Liu
8252ff41db [client] Add bind activity listener to bypass udp sockets (#4646) 2025-10-16 15:58:29 +02:00
Viktor Liu
277aa2b7cc [client] Fix missing flag values in profiles (#4650) 2025-10-16 15:13:41 +02:00
John Conley
bb37dc89ce [management] feat: Basic PocketID IDP integration (#4529) 2025-10-16 10:46:29 +02:00
Viktor Liu
000e99e7f3 [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) 2025-10-13 17:50:16 +02:00
Maycon Santos
0d2e67983a [misc] Add service definition for netbird-signal (#4620) 2025-10-10 19:16:48 +02:00
Pascal Fischer
5151f19d29 [management] pass temporary flag to validator (#4599) 2025-10-10 16:15:51 +02:00
Kostya Leschenko
bedd3cabc9 [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) 2025-10-10 15:24:24 +02:00
hakansa
d35a845dbd [management] sync all other peers on peer add/remove (#4614) 2025-10-09 21:18:00 +02:00
hakansa
4e03f708a4 fix dns forwarder port update (#4613)
fix dns forwarder port update (#4613)
2025-10-09 17:39:02 +03:00
Ashley
654aa9581d [client,gui] Update url_windows.go to offer arm64 executable download (#4586) 2025-10-08 21:27:32 +02:00
Zoltan Papp
9021bb512b [client] Recreate agent when receive new session id (#4564)
When an ICE agent connection was in progress, new offers were being ignored. This was incorrect logic because the remote agent could be restarted at any time.
In this change, whenever a new session ID is received, the ongoing handshake is closed and a new one is started.
2025-10-08 17:14:24 +02:00
hakansa
768332820e [client] Implement DNS query caching in DNSForwarder (#4574)
implements DNS query caching in the DNSForwarder to improve performance and provide fallback responses when upstream DNS servers fail. The cache stores successful DNS query results and serves them when upstream resolution fails.

- Added a new cache component to store DNS query results by domain and query type
- Integrated cache storage after successful DNS resolutions
- Enhanced error handling to serve cached responses as fallback when upstream DNS fails
2025-10-08 16:54:27 +02:00
hakansa
229c65ffa1 Enhance showLoginURL to include connection status check and auto-close functionality (#4525) 2025-10-08 12:42:15 +02:00
Zoltan Papp
4d33567888 [client] Remove endpoint address on peer disconnect, retain status for activity recording (#4228)
* When a peer disconnects, remove the endpoint address to avoid sending traffic to a non-existent address, but retain the status for the activity recorder.
2025-10-08 03:12:16 +02:00
Viktor Liu
88467883fc [management,signal] Remove ws-proxy read deadline (#4598) 2025-10-06 22:05:48 +02:00
Viktor Liu
954f40991f [client,management,signal] Handle grpc from ws proxy internally instead of via tcp (#4593) 2025-10-06 21:22:19 +02:00
Maycon Santos
34341d95a9 Adjust signal port for websocket connections (#4594) 2025-10-06 15:22:02 -03:00
134 changed files with 4642 additions and 1420 deletions

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0 FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \

View File

@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
if err != nil { if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
} }
@@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
func getStatusOutput(cmd *cobra.Command, anon bool) string { func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context()) statusResp, err := getStatus(cmd.Context(), true)
if err != nil { if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/exec"
"os/user" "os/user"
"runtime" "runtime"
"strings" "strings"
@@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil { if err := openBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
} }
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx) resp, err := getStatus(ctx, false)
if err != nil { if err != nil {
return err return err
} }
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatus(ctx context.Context) (*proto.StatusResponse, error) { func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }

View File

@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return "" return ""
} }
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := "" actionSuffix := ""
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
actionSuffix = "-drop" actionSuffix = "-drop"

View File

@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {
if port == nil { if port == nil {
return nil return nil

View File

@@ -151,14 +151,20 @@ type Manager interface {
DisableRouting() error DisableRouting() error
// AddDNATRule adds a DNAT rule // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
AddDNATRule(ForwardRule) (Rule, error) AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule // DeleteDNATRule deletes the outbound DNAT rule.
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes // UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets // applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork( func (r *router) applyNetwork(
network firewall.Network, network firewall.Network,

View File

@@ -22,6 +22,8 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64 PacketsRx atomic.Uint64
BytesTx atomic.Uint64 BytesTx atomic.Uint64
BytesRx atomic.Uint64 BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists { if exists {
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
return key, true return key, uint16(conn.DNATOrigPort.Load()), true
} }
return key, false return key, 0, false
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
// if (inverted direction) conn is not tracked, track this direction return origPort
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
} }
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
return 0
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 { if exists || flags&TCPSyn == 0 {
return return
} }
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false) conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew)) conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
t.logger.Trace2("New %s TCP connection: %s", direction, key) if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
t.mutex.Lock() t.mutex.Lock()
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
} }
} }
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.tickerCancel() t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80) serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound) // 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
key := ConnKey{ key := ConnKey{
SrcIP: clientIP, SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake // 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer // 4. Test data transfer
// Client sends data // Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
// Server sends ACK for data // Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data // Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
// Verify state and counters // Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState()) require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
// if (inverted direction) conn is not tracked, track this direction if exists {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) return origPort
} }
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
return 0
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
} }
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
return key, true return key, uint16(conn.DNATOrigPort.Load()), true
} }
return key, false return key, 0, false
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists { if exists {
return return
} }
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
} }
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace2("New %s UDP connection: %s", direction, key) if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }

View File

@@ -109,6 +109,10 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex dnatMutex sync.RWMutex
dnatBiMap *biDNATMap dnatBiMap *biDNATMap
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
} }
// decoder for packages // decoder for packages
@@ -122,6 +126,8 @@ type decoder struct {
icmp6 layers.ICMPv6 icmp6 layers.ICMPv6
decoded []gopacket.LayerType decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
@@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr), dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
} }
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
@@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true return true
} }
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, packetData, size)
m.translateOutboundDNAT(packetData, d) m.translateOutboundDNAT(packetData, d)
return false return false
@@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
} }
@@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
} }
d.dnatOrigPort = 0
} }
// udpHooksDrop checks if any UDP hooks should drop the packet // udpHooksDrop checks if any UDP hooks should drop the packet
@@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false return false
} }
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated { if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses // Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true return true
} }
srcIP, dstIP = m.extractIPs(d) srcIP, dstIP = m.extractIPs(d)

View File

@@ -50,6 +50,8 @@ type logMessage struct {
arg4 any arg4 any
arg5 any arg5 any
arg6 any arg6 any
arg7 any
arg8 any
} }
// Logger is a high-performance, non-blocking logger // Logger is a high-performance, non-blocking logger
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
func (l *Logger) Error(format string) { func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
select { select {
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
} }
} }
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
} }
} }
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0] *buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
argCount++ argCount++
if msg.arg6 != nil { if msg.arg6 != nil {
argCount++ argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
} }
} }
} }
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6: case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
case 7:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
case 8:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
} }
*buf = append(*buf, formatted...) *buf = append(*buf, formatted...)
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error {
case <-done: case <-done:
return nil return nil
} }
} }

View File

@@ -5,7 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -13,6 +15,21 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 { func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 { if len(header) < 20 {
return 0 return 0
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 { func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32 var sum1, sum2, sum3, sum4 uint32
i := 0 i := 0
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct { type biDNATMap struct {
forward map[netip.Addr]netip.Addr forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr
} }
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap { func newBiDNATMap() *biDNATMap {
return &biDNATMap{ return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr), forward: make(map[netip.Addr]netip.Addr),
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
} }
} }
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) { func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated b.forward[original] = translated
b.reverse[translated] = original b.reverse[translated] = original
} }
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) { func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists { if translated, exists := b.forward[original]; exists {
delete(b.forward, original) delete(b.forward, original)
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
} }
} }
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original] translated, exists := b.forward[original]
return translated, exists return translated, exists
} }
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated] original, exists := b.reverse[translated]
return original, exists return original, exists
} }
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() { if !originalAddr.IsValid() {
return fmt.Errorf("invalid IP addresses") return fmt.Errorf("invalid original IP address")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
} }
if m.localipmanager.IsLocalIP(translatedAddr) { if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil { if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap() m.dnatBiMap = newBiDNATMap()
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil return nil
} }
// RemoveInternalDNATMapping removes a 1:1 IP address mapping // RemoveInternalDNATMapping removes a 1:1 IP address mapping.
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil return nil
} }
// getDNATTranslation returns the translated address if a mapping exists // getDNATTranslation returns the translated address if a mapping exists.
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return addr, false return addr, false
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists return translated, exists
} }
// findReverseDNATMapping finds original address for return traffic // findReverseDNATMapping finds original address for return traffic.
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return translatedAddr, false return translatedAddr, false
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists return original, exists
} }
// translateOutboundDNAT applies DNAT translation to outbound packets // translateOutboundDNAT applies DNAT translation to outbound packets.
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP) translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("Failed to rewrite packet destination: %v", err) m.logger.Error1("failed to rewrite packet destination: %v", err)
return false return false
} }
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true return true
} }
// translateInboundReverse applies reverse DNAT to inbound return traffic // translateInboundReverse applies reverse DNAT to inbound return traffic.
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP) originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("Failed to rewrite packet source: %v", err) m.logger.Error1("failed to rewrite packet source: %v", err)
return false return false
} }
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true return true
} }
// rewritePacketDestination replaces destination IP in the packet // rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { if !newIP.Is4() {
return ErrIPv4Only return ErrIPv4Only
} }
var oldDst [4]byte var oldIP [4]byte
copy(oldDst[:], packetData[16:20]) copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newDst := newIP.As4() newIPBytes := newIP.As4()
copy(packetData[16:20], newDst[:]) copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4 ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length") return errInvalidIPHeaderLength
} }
binary.BigEndian.PutUint16(packetData[10:12], 0) binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
if len(d.decoded) > 1 { if len(d.decoded) > 1 {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen) m.updateICMPChecksum(packetData, ipHeaderLen)
} }
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
return nil return nil
} }
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 { if len(packetData) < tcpStart+18 {
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen udpStart := ipHeaderLen
if len(packetData) < udpStart+8 { if len(packetData) < udpStart+8 {
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 { if len(packetData) < icmpStart+8 {
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum) binary.BigEndian.PutUint16(icmpData[2:4], checksum)
} }
// incrementalUpdate performs incremental checksum update per RFC 1624 // incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum) sum := uint32(^oldChecksum)
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) // AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return nil, errNatNotSupported return nil, errNatNotSupported
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule) return m.nativeFirewall.AddDNATRule(rule)
} }
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) // DeleteDNATRule deletes outbound DNAT rule.
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errNatNotSupported return errNatNotSupported
} }
return m.nativeFirewall.DeleteDNATRule(rule) return m.nativeFirewall.DeleteDNATRule(rule)
} }
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
if rule.origPort != port {
continue
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}

View File

@@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
} }
}) })
} }
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP) err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping") require.Error(t, err, "Should error when removing non-existent mapping")
} }
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,25 +16,33 @@ type PacketStage int
const ( const (
StageReceived PacketStage = iota StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack StageConntrack
StagePeerACL StagePeerACL
StageRouting StageRouting
StageRouteACL StageRouteACL
StageForwarding StageForwarding
StageCompleted StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
) )
const msgProcessingCompleted = "Processing completed" const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string { func (s PacketStage) String() string {
return map[PacketStage]string{ return map[PacketStage]string{
StageReceived: "Received", StageReceived: "Received",
StageConntrack: "Connection Tracking", StageInboundPortDNAT: "Inbound Port DNAT",
StagePeerACL: "Peer ACL", StageInbound1to1NAT: "Inbound 1:1 NAT",
StageRouting: "Routing", StageConntrack: "Connection Tracking",
StageRouteACL: "Route ACL", StagePeerACL: "Peer ACL",
StageForwarding: "Forwarding", StageRouting: "Routing",
StageCompleted: "Completed", StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
}[s] }[s]
} }
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
} }
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
dropped := m.filterOutbound(packetData, 0) dropped := m.filterOutbound(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
} }
return trace return trace
} }
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageCompleted, StageCompleted,
@@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageCompleted, StageCompleted,
}, },
@@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted, StageCompleted,
}, },
expectedAllow: true, expectedAllow: true,
@@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
StageCompleted, StageCompleted,

View File

@@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled { // for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil { if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -37,9 +38,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
} }
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// for js, outer websocket layer takes care of tls verification via WithCustomDialer RootCAs: certPool,
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool,
})) }))
} }

View File

@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
// Get the existing peer to preserve its allowed IPs
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
removePeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
}
//Re-add the peer without the endpoint but same AllowedIPs
reAddPeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
AllowedIPs: existingPeer.AllowedIPs,
ReplaceAllowedIPs: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
return fmt.Errorf(
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
)
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error { func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
ipcStr, err := c.device.IpcGet()
if err != nil {
return fmt.Errorf("get IPC config: %w", err)
}
// Parse current status to get allowed IPs for the peer
stats, err := parseStatus(c.deviceName, ipcStr)
if err != nil {
return fmt.Errorf("parse IPC config: %w", err)
}
var allowedIPs []net.IPNet
found := false
for _, peer := range stats.Peers {
if peer.PublicKey == peerKey {
allowedIPs = peer.AllowedIPs
found = true
break
}
}
if !found {
return fmt.Errorf("peer %s not found", peerKey)
}
// remove the peer from the WireGuard configuration
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return fmt.Errorf("failed to remove peer: %s", ipcErr)
}
// Build the peer config
peer = wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: allowedIPs,
}
config = wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
return fmt.Errorf("remove endpoint address: %w", err)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -23,4 +23,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device Device() *wgdevice.Device
GetNet() *netstack.Net GetNet() *netstack.Net
GetICEBind() device.EndpointManager
} }

View File

@@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *WGTunDevice) GetICEBind() EndpointManager {
return t.iceBind
}
func routesToString(routes []string) string { func routesToString(routes []string) string {
return strings.Join(routes, ";") return strings.Join(routes, ";")
} }

View File

@@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) GetNet() *netstack.Net { func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns nil for kernel mode devices
func (t *TunKernelDevice) GetICEBind() EndpointManager {
return nil
}

View File

@@ -21,6 +21,7 @@ type Bind interface {
conn.Bind conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder ActivityRecorder() *bind.ActivityRecorder
EndpointManager
} }
type TunNetstackDevice struct { type TunNetstackDevice struct {
@@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
func (t *TunNetstackDevice) GetNet() *netstack.Net { func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net return t.net
} }
// GetICEBind returns the bind instance
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
return t.bind
}

View File

@@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
func (t *USPDevice) GetNet() *netstack.Net { func (t *USPDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -0,0 +1,13 @@
package device
import (
"net"
"net/netip"
)
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
// Implemented by bind.ICEBind and bind.RelayBindJS.
type EndpointManager interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
}

View File

@@ -21,4 +21,5 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
} }

View File

@@ -21,4 +21,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device Device() *wgdevice.Device
GetNet() *netstack.Net GetNet() *netstack.Net
GetICEBind() device.EndpointManager
} }

View File

@@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy() return w.wgProxyFactory.GetProxy()
} }
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun == nil {
return nil
}
return w.tun.GetICEBind()
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool { func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind return w.userspaceBind
@@ -148,6 +159,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing endpoint address: %s", peerKey)
return w.configurer.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock() w.mu.Lock()

View File

@@ -29,11 +29,6 @@ type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
} }
type protoMatch struct {
ips map[string]int
policyID []byte
}
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
firewall firewall.Manager firewall firewall.Manager
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
} }
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap) rules := networkMap.FirewallRules
enableSSH := networkMap.PeerConfig != nil && enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled networkMap.PeerConfig.SshConfig.SshEnabled
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
// if TCP protocol rules not squashed and SSH enabled // If SSH enabled, add default firewall rule which accepts connection to any peer
// we add default firewall rule which accepts connection to any peer // in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
// in the network by SSH (TCP 22 port).
if enableSSH { if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{ rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
} }
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic.
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
totalIPs++
}
}
in := map[mgmProto.RuleProtocol]*protoMatch{}
out := map[mgmProto.RuleProtocol]*protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if:
// 1. Any of the rule has DROP action type.
// 2. Any of rule contains Port.
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
}
// special case, when we receive this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return
}
ipset := protocols[r.Protocol].ips
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
}
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
}
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
protocolOrders := []mgmProto.RuleProtocol{
mgmProto.RuleProtocol_ALL,
mgmProto.RuleProtocol_ICMP,
mgmProto.RuleProtocol_TCP,
mgmProto.RuleProtocol_UDP,
}
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
// don't squash if :
// 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network.
continue
}
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
PolicyID: match.policyID,
})
squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
}
}
}
squash(in, mgmProto.RuleDirection_IN)
squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
if len(squashedRules) == 0 {
return networkMap.FirewallRules, squashedProtocols
}
var rules []*mgmProto.FirewallRule
// filter out rules which was squashed from final list
// if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
}
}
rules = append(rules, r)
}
return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector // getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)

View File

@@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) {
}) })
} }
func TestDefaultManagerSquashRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules))
r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
}
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) { func TestPortInfoEmpty(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client. config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states. state.json: Anonymized client state dump containing netbird states for the active profile.
mutex.prof: Mutex profiling information. mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information. goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information. block.prof: Block profiling information.
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
return nil return nil
} }
log.Debugf("Adding state file from: %s", path)
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {

View File

@@ -14,6 +14,9 @@ type WGIface interface {
} }
func (g *BundleGenerator) addWgShow() error { func (g *BundleGenerator) addWgShow() error {
if g.statusRecorder == nil {
return fmt.Errorf("no status recorder available for wg show")
}
result, err := g.statusRecorder.PeersStatus() result, err := g.statusRecorder.PeersStatus()
if err != nil { if err != nil {
return err return err

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
} }
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var ( var (
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string
) )
err = s.recordSystemDNSSettings(true) if err := s.recordSystemDNSSettings(true); err != nil {
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error()) log.Errorf("unable to update record of System's DNS config: %s", err.Error())
} }
if config.RouteAll { if config.RouteAll {
searchDomains = append(searchDomains, "\"\"") searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS() if err := s.addLocalDNS(); err != nil {
if err != nil { log.Warnf("failed to add local DNS: %v", err)
log.Infof("failed to enable split DNS")
} }
s.updateState(stateManager)
} }
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else { } else {
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil { if err != nil {
return fmt.Errorf("add match domains: %w", err) return fmt.Errorf("add match domains: %w", err)
} }
s.updateState(stateManager)
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 { if len(searchDomains) != 0 {
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil { if err != nil {
return fmt.Errorf("add search domains: %w", err) return fmt.Errorf("add search domains: %w", err)
} }
s.updateState(stateManager)
if err := s.flushDNSCache(); err != nil { if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err) log.Errorf("failed to flush DNS cache: %v", err)
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return nil return nil
} }
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (s *systemConfigurator) string() string { func (s *systemConfigurator) string() string {
return "scutil" return "scutil"
} }
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error { func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
if err := s.recordSystemDNSSettings(true); err != nil { if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration")
return fmt.Errorf("recordSystemDNSSettings(): %w", err) return fmt.Errorf("recordSystemDNSSettings(): %w", err)
} }
} }
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server") log.Info("Not enabling local DNS server")
return nil
}
if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
} }
return nil return nil

View File

@@ -0,0 +1,111 @@
//go:build !ios
package dns
import (
"context"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
defer func() {
require.NoError(t, sm.Stop(context.Background()))
}()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
config := HostDNSConfig{
ServerIP: netip.MustParseAddr("100.64.0.1"),
ServerPort: 53,
RouteAll: true,
Domains: []DomainConfig{
{Domain: "example.com", MatchOnly: true},
},
}
err := configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
defer func() {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}()
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
t.Logf("Key %s exists before cleanup", key)
}
}
sm2 := statemanager.New(stateFile)
sm2.RegisterState(&ShutdownState{})
err = sm2.LoadState(&ShutdownState{})
require.NoError(t, err)
state := sm2.GetState(&ShutdownState{})
if state == nil {
t.Skip("State not saved, skipping cleanup test")
}
shutdownState, ok := state.(*ShutdownState)
require.True(t, ok)
err = shutdownState.Cleanup()
require.NoError(t, err)
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}
func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
output, err := cmd.CombinedOutput()
if err != nil {
if strings.Contains(string(output), "No such key") {
return false, nil
}
return false, err
}
return !strings.Contains(string(output), "No such key"), nil
}
func removeTestDNSKey(key string) error {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
_, err := cmd.CombinedOutput()
return err
}

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry"
) )
var ( var (
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{ r.updateState(stateManager)
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var searchDomains, matchDomains []string var searchDomains, matchDomains []string
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
} }
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("cleanup old dns match policies: %s", err)
}
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil { if err != nil {
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
r.nrptEntryCount = count r.nrptEntryCount = count
} else { } else {
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
r.nrptEntryCount = 0 r.nrptEntryCount = 0
} }
if err := stateManager.UpdateState(&ShutdownState{ r.updateState(stateManager)
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
if err := r.updateSearchDomains(searchDomains); err != nil { if err := r.updateSearchDomains(searchDomains); err != nil {
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil return nil
} }
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
} }
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil { if err != nil {
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
} }
defer closer(regKey) defer closer(regKey)

View File

@@ -0,0 +1,102 @@
package dns
import (
"fmt"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry"
)
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
}
k.Close()
return true, nil
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}

View File

@@ -31,6 +31,7 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusResolvConfModeForeign = "foreign" systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Warnf("failed to set DNSSEC to 'no': %v", err) log.Warnf("failed to set DNSSEC to 'no': %v", err)
} }
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
}
var ( var (
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string

View File

@@ -7,6 +7,7 @@ import (
) )
type ShutdownState struct { type ShutdownState struct {
CreatedKeys []string
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
for _, key := range s.CreatedKeys {
manager.createdKeys[key] = struct{}{}
}
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }

View File

@@ -0,0 +1,78 @@
package dnsfwd
import (
"net/netip"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
)
type cache struct {
mu sync.RWMutex
records map[string]*cacheEntry
}
type cacheEntry struct {
ip4Addrs []netip.Addr
ip6Addrs []netip.Addr
}
func newCache() *cache {
return &cache{
records: make(map[string]*cacheEntry),
}
}
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.records[normalizeDomain(domain)]
if !exists {
return nil, false
}
switch reqType {
case dns.TypeA:
return slices.Clone(entry.ip4Addrs), true
case dns.TypeAAAA:
return slices.Clone(entry.ip6Addrs), true
default:
return nil, false
}
}
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
norm := normalizeDomain(domain)
entry, exists := c.records[norm]
if !exists {
entry = &cacheEntry{}
c.records[norm] = entry
}
switch reqType {
case dns.TypeA:
entry.ip4Addrs = slices.Clone(addrs)
case dns.TypeAAAA:
entry.ip6Addrs = slices.Clone(addrs)
}
}
// unset removes cached entries for the given domain and request type.
func (c *cache) unset(domain string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.records, normalizeDomain(domain))
}
// normalizeDomain converts an input domain into a canonical form used as cache key:
// lowercase and fully-qualified (with trailing dot).
func normalizeDomain(domain string) string {
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
return dns.Fqdn(strings.ToLower(domain))
}

View File

@@ -0,0 +1,86 @@
package dnsfwd
import (
"net/netip"
"testing"
)
func mustAddr(t *testing.T, s string) netip.Addr {
t.Helper()
a, err := netip.ParseAddr(s)
if err != nil {
t.Fatalf("parse addr %s: %v", s, err)
}
return a
}
func TestCacheNormalization(t *testing.T) {
c := newCache()
// Mixed case, without trailing dot
domainInput := "ExAmPlE.CoM"
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
// Lookup with lower, with trailing dot
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
}
// Lookup with different casing again
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
}
}
func TestCacheSeparateTypes(t *testing.T) {
c := newCache()
domain := "test.local"
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
c.set(domain, 1 /* A */, ipv4)
c.set(domain, 28 /* AAAA */, ipv6)
got4, ok4 := c.get(domain, 1)
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
}
got6, ok6 := c.get(domain, 28)
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
}
}
func TestCacheCloneOnGetAndSet(t *testing.T) {
c := newCache()
domain := "clone.test"
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
c.set(domain, 1, src)
// Mutate source slice; cache should be unaffected
src[0] = mustAddr(t, "9.9.9.9")
got, ok := c.get(domain, 1)
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
}
// Mutate returned slice; internal cache should remain unchanged
got[0] = mustAddr(t, "4.4.4.4")
got2, ok2 := c.get(domain, 1)
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
}
}
func TestCacheMiss(t *testing.T) {
c := newCache()
if got, ok := c.get("missing.example", 1); ok || got != nil {
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
}
}

View File

@@ -46,6 +46,7 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewaller firewall firewaller
resolver resolver resolver resolver
cache *cache
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
cache: newCache(),
} }
} }
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
// remove cache entries for domains that no longer appear
f.removeStaleCacheEntries(f.fwdEntries, entries)
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated DNS forwarder with %d domains", len(entries)) log.Debugf("Updated DNS forwarder with %d domains", len(entries))
} }
// removeStaleCacheEntries unsets cache items for domains that were present
// in the old list but not present in the new list.
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
if f.cache == nil {
return
}
newSet := make(map[string]struct{}, len(newEntries))
for _, e := range newEntries {
if e == nil {
continue
}
newSet[e.Domain.PunycodeString()] = struct{}{}
}
for _, e := range oldEntries {
if e == nil {
continue
}
pattern := e.Domain.PunycodeString()
if _, ok := newSet[pattern]; !ok {
f.cache.unset(pattern)
}
}
}
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
var result *multierror.Error var result *multierror.Error
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp return resp
} }
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
resp.Rcode = dns.RcodeSuccess resp.Rcode = dns.RcodeSuccess
} }
// handleDNSError processes DNS lookup errors and sends an appropriate error response // handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { func (f *DNSForwarder) handleDNSError(
ctx context.Context,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
switch { log.Warnf(errResolveFailed, domain, err)
case errors.As(err, &dnsErr): if writeErr := w.WriteMsg(resp); writeErr != nil {
resp.Rcode = dns.RcodeServerFailure log.Errorf("failed to write failure DNS response: %v", writeErr)
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
} }
return
}
if dnsErr.Server != "" { // NotFound: set NXDOMAIN / appropriate code via helper.
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) if dnsErr.IsNotFound {
} else { f.setResponseCodeForNotFound(ctx, resp, domain, qType)
log.Warnf(errResolveFailed, domain, err) if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
} }
default: f.cache.set(domain, question.Qtype, nil)
resp.Rcode = dns.RcodeServerFailure return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
return
}
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err) log.Warnf(errResolveFailed, domain, err)
} }
if err := w.WriteMsg(resp); err != nil { // Write final failure response.
log.Errorf("failed to write failure DNS response: %v", err) if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
} }
} }

View File

@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
} }
// Ensures that when the first query succeeds and populates the cache,
// a subsequent upstream failure still returns a successful response from cache.
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("1.2.3.4")
// First call resolves successfully and populates cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{ip}, nil).Once()
// Second call fails upstream; forwarder should serve from cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
// First query: populate cache
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
// Verifies that cache normalization works across casing and trailing dot variations.
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("ExAmPlE.CoM")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("9.8.7.6")
// Initial resolution with mixed case to populate cache
mixedQuery := "ExAmPlE.CoM"
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
Return([]netip.Addr{ip}, nil).Once()
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Subsequent query without dot and upper case should hit cache even if upstream fails
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios // Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{} mockFirewall := &MockFirewall{}

View File

@@ -4,7 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync" "net/netip"
"os"
"strconv"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -12,18 +14,14 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
var (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
listenPort uint16 = 5353
listenPortMu sync.RWMutex
)
const ( const (
dnsTTL = 60 //seconds dnsTTL = 60
envServerPort = "NB_DNS_FORWARDER_PORT"
) )
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
@@ -36,24 +34,30 @@ type ForwarderEntry struct {
type Manager struct { type Manager struct {
firewall firewall.Manager firewall firewall.Manager
statusRecorder *peer.Status statusRecorder *peer.Status
localAddr netip.Addr
serverPort uint16
fwRules []firewall.Rule fwRules []firewall.Rule
tcpRules []firewall.Rule tcpRules []firewall.Rule
dnsForwarder *DNSForwarder dnsForwarder *DNSForwarder
port uint16
} }
func ListenPort() uint16 { func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
listenPortMu.RLock() serverPort := nbdns.ForwarderServerPort
defer listenPortMu.RUnlock() if envPort := os.Getenv(envServerPort); envPort != "" {
return listenPort if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
} serverPort = uint16(port)
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
} else {
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
}
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
return &Manager{ return &Manager{
firewall: fw, firewall: fw,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
port: port, localAddr: localAddr,
serverPort: serverPort,
} }
} }
@@ -67,13 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err return err
} }
if m.port > 0 { if m.localAddr.IsValid() && m.firewall != nil {
listenPortMu.Lock() if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
listenPort = m.port log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
listenPortMu.Unlock() } else {
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
}
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
} else {
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
}
} }
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
go func() { go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil { if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists // todo handle close error if it is exists
@@ -98,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error {
} }
var mErr *multierror.Error var mErr *multierror.Error
if m.localAddr.IsValid() && m.firewall != nil {
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
}
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
}
if err := m.dropDNSFirewall(); err != nil { if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err) mErr = multierror.Append(mErr, err)
} }
@@ -113,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error {
func (m *Manager) allowDNSFirewall() error { func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{ dport := &firewall.Port{
IsRange: false, IsRange: false,
Values: []uint16{ListenPort()}, Values: []uint16{m.serverPort},
} }
if m.firewall == nil { if m.firewall == nil {

View File

@@ -203,8 +203,7 @@ type Engine struct {
wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup wgIfaceMonitorWg sync.WaitGroup
// dns forwarder port probeStunTurn *relay.StunTurnProbe
dnsFwdPort uint16
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -247,7 +246,7 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
dnsFwdPort: dnsfwd.ListenPort(), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
} }
sm := profilemanager.NewServiceManager("") sm := profilemanager.NewServiceManager("")
@@ -1060,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
} }
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1084,7 +1087,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules // Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
@@ -1208,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
if forwarderPort == 0 {
forwarderPort = nbdns.ForwarderClientPort
}
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0),
ForwarderPort: forwarderPort,
} }
for _, zone := range protoDNSConfig.GetCustomZones() { for _, zone := range protoDNSConfig.GetCustomZones() {
@@ -1667,7 +1676,7 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states. // and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool { func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy() signalHealthy := e.signal.IsHealthy()
@@ -1699,8 +1708,12 @@ func (e *Engine) RunHealthProbes() bool {
} }
e.syncMsgMux.Unlock() e.syncMsgMux.Unlock()
var results []relay.ProbeResult
results := e.probeICE(stuns, turns) if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results) e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true relayHealthy := true
@@ -1717,13 +1730,6 @@ func (e *Engine) RunHealthProbes() bool {
return allHealthy return allHealthy
} }
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
)
}
// restartEngine restarts the engine by cancelling the client context // restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() { func (e *Engine) restartEngine() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
@@ -1843,7 +1849,6 @@ func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) updateDNSForwarder( func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
forwarderPort uint16,
) { ) {
if e.config.DisableServerRoutes { if e.config.DisableServerRoutes {
return return
@@ -1860,20 +1865,17 @@ func (e *Engine) updateDNSForwarder(
} }
if len(fwdEntries) > 0 { if len(fwdEntries) > 0 {
switch { if e.dnsForwardMgr == nil {
case e.dnsForwardMgr == nil: localAddr := e.wgInterface.Address().IP
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err) log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
log.Infof("started domain router service with %d entries", len(fwdEntries))
case e.dnsFwdPort != forwarderPort:
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
e.restartDnsFwd(fwdEntries, forwarderPort)
e.dnsFwdPort = forwarderPort
default: log.Infof("started domain router service with %d entries", len(fwdEntries))
} else {
e.dnsForwardMgr.UpdateDomains(fwdEntries) e.dnsForwardMgr.UpdateDomains(fwdEntries)
} }
} else if e.dnsForwardMgr != nil { } else if e.dnsForwardMgr != nil {
@@ -1883,20 +1885,6 @@ func (e *Engine) updateDNSForwarder(
} }
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
}
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
// stop and start the forwarder to apply the new port
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} }
func (e *Engine) GetNet() (*netstack.Net, error) { func (e *Engine) GetNet() (*netstack.Net, error) {

View File

@@ -105,6 +105,10 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time LastActivitiesFunc func() map[string]monotime.Time
} }
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) { func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented") return nil, fmt.Errorf("not implemented")
} }

View File

@@ -28,6 +28,7 @@ type wgIfaceBase interface {
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error

View File

@@ -0,0 +1,82 @@
package activity
import (
"context"
"io"
"net"
"time"
)
// lazyConn detects activity when WireGuard attempts to send packets.
// It does not deliver packets, only signals that activity occurred.
type lazyConn struct {
activityCh chan struct{}
ctx context.Context
cancel context.CancelFunc
}
// newLazyConn creates a new lazyConn for activity detection.
func newLazyConn() *lazyConn {
ctx, cancel := context.WithCancel(context.Background())
return &lazyConn{
activityCh: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
}
}
// Read blocks until the connection is closed.
func (c *lazyConn) Read(_ []byte) (n int, err error) {
<-c.ctx.Done()
return 0, io.EOF
}
// Write signals activity detection when ICEBind routes packets to this endpoint.
func (c *lazyConn) Write(b []byte) (n int, err error) {
if c.ctx.Err() != nil {
return 0, io.EOF
}
select {
case c.activityCh <- struct{}{}:
default:
}
return len(b), nil
}
// ActivityChan returns the channel that signals when activity is detected.
func (c *lazyConn) ActivityChan() <-chan struct{} {
return c.activityCh
}
// Close closes the connection.
func (c *lazyConn) Close() error {
c.cancel()
return nil
}
// LocalAddr returns the local address.
func (c *lazyConn) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}
// RemoteAddr returns the remote address.
func (c *lazyConn) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}
// SetDeadline sets the read and write deadlines.
func (c *lazyConn) SetDeadline(_ time.Time) error {
return nil
}
// SetReadDeadline sets the deadline for future Read calls.
func (c *lazyConn) SetReadDeadline(_ time.Time) error {
return nil
}
// SetWriteDeadline sets the deadline for future Write calls.
func (c *lazyConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@@ -0,0 +1,127 @@
package activity
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
type bindProvider interface {
GetBind() device.EndpointManager
}
const (
// lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers.
// The actual routing is done via fakeIP in ICEBind, not by this port.
lazyBindPort = 17473
)
// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode.
type BindListener struct {
wgIface WgInterface
peerCfg lazyconn.PeerConfig
done sync.WaitGroup
lazyConn *lazyConn
bind device.EndpointManager
fakeIP netip.Addr
}
// NewBindListener creates a listener that passes data directly through bind using LazyConn.
// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range.
func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) {
fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs)
if err != nil {
return nil, fmt.Errorf("derive fake IP: %w", err)
}
d := &BindListener{
wgIface: wgIface,
peerCfg: cfg,
bind: bind,
fakeIP: fakeIP,
}
if err := d.setupLazyConn(); err != nil {
return nil, fmt.Errorf("setup lazy connection: %v", err)
}
d.done.Add(1)
return d, nil
}
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
if len(allowedIPs) == 0 {
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
}
ourNetwork := wgIface.Address().Network
var peerIP netip.Addr
for _, allowedIP := range allowedIPs {
ip := allowedIP.Addr()
if !ip.Is4() {
continue
}
if ourNetwork.Contains(ip) {
peerIP = ip
break
}
}
if !peerIP.IsValid() {
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
}
octets := peerIP.As4()
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
return fakeIP, nil
}
func (d *BindListener) setupLazyConn() error {
d.lazyConn = newLazyConn()
d.bind.SetEndpoint(d.fakeIP, d.lazyConn)
endpoint := &net.UDPAddr{
IP: d.fakeIP.AsSlice(),
Port: lazyBindPort,
}
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil)
}
// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed.
func (d *BindListener) ReadPackets() {
select {
case <-d.lazyConn.ActivityChan():
d.peerCfg.Log.Infof("activity detected via LazyConn")
case <-d.lazyConn.ctx.Done():
d.peerCfg.Log.Infof("exit from activity listener")
}
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.lazyConn.Close()
d.bind.RemoveEndpoint(d.fakeIP)
d.done.Done()
}
// Close stops the listener and cleans up resources.
func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
if err := d.lazyConn.Close(); err != nil {
d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err)
}
d.done.Wait()
}

View File

@@ -0,0 +1,291 @@
package activity
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn
}
func newMockEndpointManager() *mockEndpointManager {
return &mockEndpointManager{
endpoints: make(map[netip.Addr]net.Conn),
}
}
func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
m.endpoints[fakeIP] = conn
}
func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) {
delete(m.endpoints, fakeIP)
}
func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn {
return m.endpoints[fakeIP]
}
// MockWGIfaceBind mocks WgInterface with bind support
type MockWGIfaceBind struct {
endpointMgr *mockEndpointManager
}
func (m *MockWGIfaceBind) RemovePeer(string) error {
return nil
}
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
func (m *MockWGIfaceBind) IsUserspaceBind() bool {
return true
}
func (m *MockWGIfaceBind) Address() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
}
}
func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr
}
func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
expectedFakeIP := netip.MustParseAddr("127.2.0.2")
conn := mockEndpointMgr.GetEndpoint(expectedFakeIP)
require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager")
_, ok := conn.(*lazyConn)
assert.True(t, ok, "Registered endpoint should be a lazyConn")
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
}
func TestBindListener_ActivityDetection(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
activityDetected := make(chan struct{})
go func() {
listener.ReadPackets()
close(activityDetected)
}()
fakeIP := listener.fakeIP
conn := mockEndpointMgr.GetEndpoint(fakeIP)
require.NotNil(t, conn, "Endpoint should be registered")
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case <-activityDetected:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity detection")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection")
}
func TestBindListener_Close(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
fakeIP := listener.fakeIP
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close")
}
func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
mgr := NewManager(mockIface)
defer mgr.Close()
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
err := mgr.MonitorPeerActivity(cfg)
require.NoError(t, err)
listener, exists := mgr.GetPeerListener(cfg.PeerConnID)
require.True(t, exists, "Peer listener should be found")
bindListener, ok := listener.(*BindListener)
require.True(t, ok, "Listener should be BindListener, got %T", listener)
fakeIP := bindListener.fakeIP
conn := mockEndpointMgr.GetEndpoint(fakeIP)
require.NotNil(t, conn, "Endpoint should be registered")
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity")
}
func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer1 := &MocPeer{PeerID: "testPeer1"}
peer2 := &MocPeer{PeerID: "testPeer2"}
mgr := NewManager(mockIface)
defer mgr.Close()
cfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
cfg2 := lazyconn.PeerConfig{
PublicKey: peer2.PeerID,
PeerConnID: peer2.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")},
Log: log.WithField("peer", "testPeer2"),
}
err := mgr.MonitorPeerActivity(cfg1)
require.NoError(t, err)
err = mgr.MonitorPeerActivity(cfg2)
require.NoError(t, err)
listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID)
require.True(t, exists, "Peer1 listener should be found")
bindListener1 := listener1.(*BindListener)
listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID)
require.True(t, exists, "Peer2 listener should be found")
bindListener2 := listener2.(*BindListener)
conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP)
require.NotNil(t, conn1, "Peer1 endpoint should be registered")
_, err = conn1.Write([]byte{0x01})
require.NoError(t, err)
conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP)
require.NotNil(t, conn2, "Peer2 endpoint should be registered")
_, err = conn2.Write([]byte{0x02})
require.NoError(t, err)
receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ {
select {
case peerConnID := <-mgr.OnActivityChan:
receivedPeers[peerConnID] = true
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications")
}
}
assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received")
assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received")
}

View File

@@ -1,41 +0,0 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -11,26 +11,27 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
) )
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking // UDPListener uses UDP sockets for activity detection in kernel mode.
type Listener struct { type UDPListener struct {
wgIface WgInterface wgIface WgInterface
peerCfg lazyconn.PeerConfig peerCfg lazyconn.PeerConfig
conn *net.UDPConn conn *net.UDPConn
endpoint *net.UDPAddr endpoint *net.UDPAddr
done sync.Mutex done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener isClosed atomic.Bool
} }
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { // NewUDPListener creates a listener that detects activity via UDP socket reads.
d := &Listener{ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) {
d := &UDPListener{
wgIface: wgIface, wgIface: wgIface,
peerCfg: cfg, peerCfg: cfg,
} }
conn, err := d.newConn() conn, err := d.newConn()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err) return nil, fmt.Errorf("create UDP connection: %v", err)
} }
d.conn = conn d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr) d.endpoint = conn.LocalAddr().(*net.UDPAddr)
@@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error
if err := d.createEndpoint(); err != nil { if err := d.createEndpoint(); err != nil {
return nil, err return nil, err
} }
d.done.Lock() d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String())
return d, nil return d, nil
} }
func (d *Listener) ReadPackets() { // ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
func (d *UDPListener) ReadPackets() {
for { for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil { if err != nil {
@@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() {
} }
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil { if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
} }
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" // Ignore close error as it may return "use of closed network connection" if already closed.
_ = d.conn.Close()
d.done.Unlock() d.done.Unlock()
} }
func (d *Listener) Close() { // Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true) d.isClosed.Store(true)
@@ -82,16 +87,12 @@ func (d *Listener) Close() {
d.done.Lock() d.done.Lock()
} }
func (d *Listener) removeEndpoint() error { func (d *UDPListener) createEndpoint() error {
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
} }
func (d *Listener) newConn() (*net.UDPConn, error) { func (d *UDPListener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{ addr := &net.UDPAddr{
Port: 0, Port: 0,
IP: listenIP, IP: listenIP,

View File

@@ -0,0 +1,110 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestUDPListener_Creation(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
require.NotNil(t, listener.conn)
require.NotNil(t, listener.endpoint)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
}
func TestUDPListener_ActivityDetection(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
activityDetected := make(chan struct{})
go func() {
listener.ReadPackets()
close(activityDetected)
}()
conn, err := net.Dial("udp", listener.conn.LocalAddr().String())
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case <-activityDetected:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity detection")
}
}
func TestUDPListener_Close(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed")
}

View File

@@ -1,21 +1,32 @@
package activity package activity
import ( import (
"errors"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
) )
// listener defines the contract for activity detection listeners.
type listener interface {
ReadPackets()
Close()
}
type WgInterface interface { type WgInterface interface {
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
} }
type Manager struct { type Manager struct {
@@ -23,7 +34,7 @@ type Manager struct {
wgIface WgInterface wgIface WgInterface
peers map[peerid.ConnID]*Listener peers map[peerid.ConnID]listener
done chan struct{} done chan struct{}
mu sync.Mutex mu sync.Mutex
@@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager {
m := &Manager{ m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1), OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface, wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener), peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}), done: make(chan struct{}),
} }
return m return m
@@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
return nil return nil
} }
listener, err := NewListener(m.wgIface, peerCfg) listener, err := m.createListener(peerCfg)
if err != nil { if err != nil {
return err return err
} }
m.peers[peerCfg.PeerConnID] = listener
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID) go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil return nil
} }
func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) {
if !m.wgIface.IsUserspaceBind() {
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is only used on Windows and JS platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// BindListener bypasses this by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider)
if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
}
return NewBindListener(m.wgIface, provider.GetBind(), peerCfg)
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -82,8 +115,8 @@ func (m *Manager) Close() {
} }
} }
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
listener.ReadPackets() l.ReadPackets()
m.mu.Lock() m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok { if _, ok := m.peers[peerConnID]; !ok {

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
) )
@@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error {
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil return nil
} }
// Add this method to the Manager struct func (m MocWGIface) IsUserspaceBind() bool {
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { return false
}
func (m MocWGIface) Address() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
}
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
listener, exists := m.peers[peerConnID] l, exists := m.peers[peerConnID]
return listener, exists return l, exists
} }
func TestManager_MonitorPeerActivity(t *testing.T) { func TestManager_MonitorPeerActivity(t *testing.T) {
@@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
t.Fatalf("peer listener not found") t.Fatalf("peer listener not found")
} }
if err := trigger(listener.conn.LocalAddr().String()); err != nil { // Get the UDP listener's address for triggering
udpListener, ok := listener.(*UDPListener)
if !ok {
t.Fatalf("expected UDPListener")
}
if err := trigger(udpListener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err) t.Fatalf("failed to trigger activity: %v", err)
} }
@@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err) t.Fatalf("failed to monitor peer activity: %v", err)
} }
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID)
udpListener, _ := listener.(*UDPListener)
addr := udpListener.conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
@@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer1 not found") t.Fatalf("peer listener for peer1 not found")
} }
if err := trigger(listener.conn.LocalAddr().String()); err != nil { udpListener1, _ := listener.(*UDPListener)
if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err) t.Fatalf("failed to trigger activity: %v", err)
} }
@@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer2 not found") t.Fatalf("peer listener for peer2 not found")
} }
if err := trigger(listener.conn.LocalAddr().String()); err != nil { udpListener2, _ := listener.(*UDPListener)
if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err) t.Fatalf("failed to trigger activity: %v", err)
} }

View File

@@ -7,6 +7,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
) )
@@ -14,5 +15,6 @@ type WGIface interface {
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
} }

View File

@@ -10,10 +10,10 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/dns"
) )
type rcvChan chan *types.EventFields type rcvChan chan *types.EventFields
@@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection // check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { if !l.dnsCollection.Load() && event.Protocol == types.UDP &&
(event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) {
return false return false
} }

View File

@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() { if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
} }
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
@@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() {
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
changed := conn.statusICE.Get() != worker.StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay { if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {

View File

@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return return
} }
onNewOffeChan := make(chan struct{}) onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{} onNewOfferChan <- struct{}{}
}) })
conn.OnRemoteOffer(OfferAnswer{ conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOffeChan: case <-onNewOfferChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return return
} }
onNewOffeChan := make(chan struct{}) onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{} onNewOfferChan <- struct{}{}
}) })
conn.OnRemoteAnswer(OfferAnswer{ conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOffeChan: case <-onNewOfferChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")

View File

@@ -0,0 +1,20 @@
package guard
import (
"os"
"strconv"
"time"
)
const (
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
)
func GetICEMonitorPeriod() time.Duration {
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
return defaultCandidatesMonitorPeriod
}

View File

@@ -16,8 +16,8 @@ import (
) )
const ( const (
candidatesMonitorPeriod = 5 * time.Minute defaultCandidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second candidateGatheringTimeout = 5 * time.Second
) )
type ICEMonitor struct { type ICEMonitor struct {
@@ -25,16 +25,19 @@ type ICEMonitor struct {
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config iceConfig icemaker.Config
tickerPeriod time.Duration
currentCandidatesAddress []string currentCandidatesAddress []string
candidatesMu sync.Mutex candidatesMu sync.Mutex
} }
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
log.Debugf("prepare ICE monitor with period: %s", period)
cm := &ICEMonitor{ cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1), ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
iceConfig: config, iceConfig: config,
tickerPeriod: period,
} }
return cm return cm
} }
@@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
return return
} }
ticker := time.NewTicker(candidatesMonitorPeriod) // Initial check to populate the candidates for later comparison
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
log.Warnf("Failed to check initial ICE candidates: %v", err)
}
ticker := time.NewTicker(cm.tickerPeriod)
defer ticker.Stop() defer ticker.Stop()
for { for {

View File

@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged) go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected) w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -44,13 +44,19 @@ type OfferAnswer struct {
} }
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
config ConnConfig config ConnConfig
signaler *Signaler signaler *Signaler
ice *WorkerICE ice *WorkerICE
relay *WorkerRelay relay *WorkerRelay
onNewOfferListeners []*OfferListener // relayListener is not blocking because the listener is using a goroutine to process the messages
// and it will only keep the latest message if multiple offers are received in a short time
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
// and also to avoid processing old offers if multiple offers are received in a short time
// the listener will always process the latest offer
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
} }
} }
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
l := NewOfferListener(offer) h.relayListener = NewAsyncOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l) }
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.iceListener = offer
} }
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
if err := h.sendAnswer(); err != nil { if err := h.sendAnswer(); err != nil {
h.log.Errorf("failed to send remote offer confirmation: %s", err) h.log.Errorf("failed to send remote offer confirmation: %s", err)
continue continue
} }
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners { if h.relayListener != nil {
listener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers") h.log.Infof("stop listening for remote offers and answers")

View File

@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
return oa.SessionID.String() return oa.SessionID.String()
} }
type OfferListener struct { type AsyncOfferListener struct {
fn callbackFunc fn callbackFunc
running bool running bool
latest *OfferAnswer latest *OfferAnswer
mu sync.Mutex mu sync.Mutex
} }
func NewOfferListener(fn callbackFunc) *OfferListener { func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
return &OfferListener{ return &AsyncOfferListener{
fn: fn, fn: fn,
} }
} }
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock() o.mu.Lock()
defer o.mu.Unlock() defer o.mu.Unlock()

View File

@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
runChan <- struct{}{} runChan <- struct{}{}
} }
hl := NewOfferListener(longRunningFn) hl := NewAsyncOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)

View File

@@ -18,4 +18,5 @@ type WGIface interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address Address() wgaddr.Address
RemoveEndpointAddress(key string) error
} }

View File

@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agentConnecting { if w.agent != nil || w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil {
// backward compatibility with old clients that do not send session ID // backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil { if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return return
} }
if w.remoteSessionID == *remoteOfferAnswer.SessionID { if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return return
} }
w.log.Debugf("agent already exists, recreate the connection") w.log.Debugf("agent already exists, recreate the connection")
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if err := w.agent.Close(); err != nil { if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
} }
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
preferredCandidateTypes = icemaker.CandidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
w.log.Debugf("recreate ICE agent") if remoteOfferAnswer.SessionID != nil {
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
}
dialerCtx, dialerCancel := context.WithCancel(w.ctx) dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil { if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err) w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return return
} }
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel w.agentDialerCancel = dialerCancel
w.agentConnecting = true w.agentConnecting = true
w.muxAgent.Unlock() if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
} else {
w.remoteSessionID = ""
}
go w.connect(dialerCtx, agent, remoteOfferAnswer) go w.connect(dialerCtx, agent, remoteOfferAnswer)
} }
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.muxAgent.Lock() w.muxAgent.Lock()
w.agentConnecting = false w.agentConnecting = false
w.lastSuccess = time.Now() w.lastSuccess = time.Now()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
}
w.muxAgent.Unlock() w.muxAgent.Unlock()
// todo: the potential problem is a race between the onConnectionStateChange // todo: the potential problem is a race between the onConnectionStateChange
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
} }
w.muxAgent.Lock() w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent { if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
w.agentConnecting = false w.agentConnecting = false
w.remoteSessionID = ""
} }
w.muxAgent.Unlock() w.muxAgent.Unlock()
} }
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority // notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected()
} }
w.closeAgent(agent, dialerCancel)
default: default:
return return
} }

View File

@@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ config := &Config{
// defaults to false only for new (post 0.26) configurations // defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(), ServerSSHAllowed: util.False(),
WgPort: iface.DefaultWgPort,
} }
if _, err := config.apply(input); err != nil { if _, err := config.apply(input); err != nil {

View File

@@ -5,11 +5,14 @@ import (
"errors" "errors"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) {
} }
} }
func TestNewProfileDefaults(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
})
require.NoError(t, err, "should create new config")
assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default")
assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default")
assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated")
assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated")
assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default")
assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820")
assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default")
assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default")
assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set")
assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set")
assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults")
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS")
assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS")
}
}
func TestWireguardPortZeroExplicit(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
// Create a new profile with explicit port 0 (random port)
explicitZero := 0
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
WireguardPort: &explicitZero,
})
require.NoError(t, err, "should create config with explicit port 0")
assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user")
// Verify it persists
readConfig, err := GetConfig(configPath)
require.NoError(t, err)
assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file")
}
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
tests := []struct {
name string
wireguardPort *int
expectedPort int
description string
}{
{
name: "no port specified uses default",
wireguardPort: nil,
expectedPort: iface.DefaultWgPort,
description: "When user doesn't specify port, default to 51820",
},
{
name: "explicit zero for random port",
wireguardPort: func() *int { v := 0; return &v }(),
expectedPort: 0,
description: "When user explicitly sets 0, use 0 for random port",
},
{
name: "explicit custom port",
wireguardPort: func() *int { v := 52000; return &v }(),
expectedPort: 52000,
description: "When user sets custom port, use that port",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
WireguardPort: tt.wireguardPort,
})
require.NoError(t, err, tt.description)
assert.Equal(t, tt.expectedPort, config.WgPort, tt.description)
})
}
}
func TestUpdateOldManagementURL(t *testing.T) { func TestUpdateOldManagementURL(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -2,6 +2,8 @@ package relay
import ( import (
"context" "context"
"crypto/sha256"
"errors"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@@ -15,6 +17,15 @@ import (
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const (
DefaultCacheTTL = 20 * time.Second
probeTimeout = 6 * time.Second
)
var (
ErrCheckInProgress = errors.New("probe check is already in progress")
)
// ProbeResult holds the info about the result of a relay probe request // ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct { type ProbeResult struct {
URI string URI string
@@ -22,8 +33,164 @@ type ProbeResult struct {
Addr string Addr string
} }
type StunTurnProbe struct {
cacheResults []ProbeResult
cacheTimestamp time.Time
cacheKey string
cacheTTL time.Duration
probeInProgress bool
probeDone chan struct{}
mu sync.Mutex
}
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
return &StunTurnProbe{
cacheTTL: cacheTTL,
}
}
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if p.probeInProgress {
doneChan := p.probeDone
p.mu.Unlock()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-doneChan:
return p.getCachedResults(cacheKey, stuns, turns)
}
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
p.mu.Unlock()
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
return p.getCachedResults(cacheKey, stuns, turns)
}
// ProbeAll probes all given servers asynchronously and returns the results
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if results := p.checkCache(cacheKey); results != nil {
p.mu.Unlock()
return results
}
if p.probeInProgress {
p.mu.Unlock()
return createErrorResults(stuns, turns)
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
log.Infof("started new probe for STUN, TURN servers")
go func() {
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
}()
p.mu.Unlock()
timer := time.NewTimer(1300 * time.Millisecond)
defer timer.Stop()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-probeDone:
// when the probe is return fast, return the results right away
return p.getCachedResults(cacheKey, stuns, turns)
case <-timer.C:
// if the probe takes longer than 1.3s, return error results to avoid blocking
return createErrorResults(stuns, turns)
}
}
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
age := time.Since(p.cacheTimestamp)
if age < p.cacheTTL {
results := append([]ProbeResult(nil), p.cacheResults...)
log.Debugf("returning cached probe results (age: %v)", age)
return results
}
}
return nil
}
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
p.mu.Lock()
defer p.mu.Unlock()
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
return append([]ProbeResult(nil), p.cacheResults...)
}
return createErrorResults(stuns, turns)
}
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
defer func() {
p.mu.Lock()
p.probeInProgress = false
p.mu.Unlock()
}()
results := make([]ProbeResult, len(stuns)+len(turns))
var wg sync.WaitGroup
for i, uri := range stuns {
wg.Add(1)
go func(idx int, stunURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = stunURI.String()
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
}(i, uri)
}
stunOffset := len(stuns)
for i, uri := range turns {
wg.Add(1)
go func(idx int, turnURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = turnURI.String()
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
}(stunOffset+i, uri)
}
wg.Wait()
p.mu.Lock()
p.cacheResults = results
p.cacheTimestamp = time.Now()
p.cacheKey = cacheKey
p.mu.Unlock()
log.Debug("Stored new probe results in cache")
}
// ProbeSTUN tries binding to the given STUN uri and acquiring an address // ProbeSTUN tries binding to the given STUN uri and acquiring an address
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() { defer func() {
if probeErr != nil { if probeErr != nil {
log.Debugf("stun probe error from %s: %s", uri, probeErr) log.Debugf("stun probe error from %s: %s", uri, probeErr)
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
} }
// ProbeTURN tries allocating a session from the given TURN URI // ProbeTURN tries allocating a session from the given TURN URI
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() { defer func() {
if probeErr != nil { if probeErr != nil {
log.Debugf("turn probe error from %s: %s", uri, probeErr) log.Debugf("turn probe error from %s: %s", uri, probeErr)
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
return relayConn.LocalAddr().String(), nil return relayConn.LocalAddr().String(), nil
} }
// ProbeAll probes all given servers asynchronously and returns the results func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
func ProbeAll( total := len(stuns) + len(turns)
ctx context.Context, results := make([]ProbeResult, total)
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
relays []*stun.URI,
) []ProbeResult {
results := make([]ProbeResult, len(relays))
var wg sync.WaitGroup allURIs := append(append([]*stun.URI{}, stuns...), turns...)
for i, uri := range relays { for i, uri := range allURIs {
ctx, cancel := context.WithTimeout(ctx, 6*time.Second) results[i] = ProbeResult{
defer cancel() URI: uri.String(),
Err: ErrCheckInProgress,
wg.Add(1) }
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
} }
wg.Wait()
return results return results
} }
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
h := sha256.New()
for _, uri := range stuns {
h.Write([]byte(uri.String()))
}
for _, uri := range turns {
h.Write([]byte(uri.String()))
}
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -1,6 +1,7 @@
package common package common
import ( import (
"sync/atomic"
"time" "time"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
@@ -25,4 +26,5 @@ type HandlerParams struct {
UseNewDNSRoute bool UseNewDNSRoute bool
Firewall manager.Manager Firewall manager.Manager
FakeIPManager *fakeip.Manager FakeIPManager *fakeip.Manager
ForwarderPort *atomic.Uint32
} }

View File

@@ -8,6 +8,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@@ -18,7 +19,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/common"
@@ -55,6 +55,7 @@ type DnsInterceptor struct {
peerStore *peerstore.Store peerStore *peerstore.Store
firewall firewall.Manager firewall firewall.Manager
fakeIPManager *fakeip.Manager fakeIPManager *fakeip.Manager
forwarderPort *atomic.Uint32
} }
func New(params common.HandlerParams) *DnsInterceptor { func New(params common.HandlerParams) *DnsInterceptor {
@@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor {
firewall: params.Firewall, firewall: params.Firewall,
fakeIPManager: params.FakeIPManager, fakeIPManager: params.FakeIPManager,
interceptedDomains: make(domainMap), interceptedDomains: make(domainMap),
forwarderPort: params.ForwarderPort,
} }
} }
@@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel() defer cancel()

View File

@@ -10,6 +10,7 @@ import (
"runtime" "runtime"
"slices" "slices"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -23,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client" "github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -54,6 +56,7 @@ type Manager interface {
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
SetFirewall(firewall.Manager) error SetFirewall(firewall.Manager) error
SetDNSForwarderPort(port uint16)
Stop(stateManager *statemanager.Manager) Stop(stateManager *statemanager.Manager)
} }
@@ -101,12 +104,13 @@ type DefaultManager struct {
disableServerRoutes bool disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.Manager fakeIPManager *fakeip.Manager
dnsForwarderPort atomic.Uint32
} }
func NewManager(config ManagerConfig) *DefaultManager { func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context) mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier) sysOps := systemops.New(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil { if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name()) nbnet.SetVPNInterfaceName(config.WGInterface.Name())
@@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
} }
dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort))
useNoop := netstack.IsEnabled() || config.DisableClientRoutes useNoop := netstack.IsEnabled() || config.DisableClientRoutes
dm.setupRefCounters(useNoop) dm.setupRefCounters(useNoop)
@@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
return nil return nil
} }
// SetDNSForwarderPort sets the DNS forwarder port for route handlers
func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
m.dnsForwarderPort.Store(uint32(port))
}
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop() m.stop()
@@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
UseNewDNSRoute: m.useNewDNSRoute, UseNewDNSRoute: m.useNewDNSRoute,
Firewall: m.firewall, Firewall: m.firewall,
FakeIPManager: m.fakeIPManager, FakeIPManager: m.fakeIPManager,
ForwarderPort: &m.dnsForwarderPort,
} }
handler := client.HandlerFromRoute(params) handler := client.HandlerFromRoute(params)
if err := handler.AddRoute(m.ctx); err != nil { if err := handler.AddRoute(m.ctx); err != nil {

View File

@@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error {
panic("implement me") panic("implement me")
} }
// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface
func (m *MockManager) SetDNSForwarderPort(port uint16) {
}
// Stop mock implementation of Stop from Manager interface // Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop(stateManager *statemanager.Manager) { func (m *MockManager) Stop(stateManager *statemanager.Manager) {
if m.StopFunc != nil { if m.StopFunc != nil {

View File

@@ -0,0 +1,8 @@
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
package systemops
// FlushMarkedRoutes is a no-op on non-BSD platforms.
func (r *SysOps) FlushMarkedRoutes() error {
return nil
}

View File

@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil) sysOps := New(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s)) sysOps.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush() return sysOps.refCounter.Flush()
} }
func (s *ShutdownState) MarshalJSON() ([]byte, error) { func (s *ShutdownState) MarshalJSON() ([]byte, error) {

View File

@@ -83,7 +83,7 @@ type SysOps struct {
localSubnetsCacheTime time.Time localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier, notifier: notifier,

View File

@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
_, intf = setupDummyInterface(t) _, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf} nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil) r := New(nil, nil)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
nexthop := Nexthop{netip.Addr{}, netIntf} nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil) r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop) err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table") require.NoError(t, err, "Failed to add route to table")

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close()) assert.NoError(t, wgInterface.Close())
}) })
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")

View File

@@ -7,19 +7,39 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"strconv" "strconv"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/route" "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
)
var routeProtoFlag int
func init() {
switch os.Getenv(envRouteProtoFlag) {
case "2":
routeProtoFlag = unix.RTF_PROTO2
case "3":
routeProtoFlag = unix.RTF_PROTO3
default:
routeProtoFlag = unix.RTF_PROTO1
}
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
func (r *SysOps) FlushMarkedRoutes() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("Skipping route flush: %v", err)
continue
}
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
continue
}
nexthop := Nexthop{
IP: routeInfo.Gw,
Intf: routeInfo.Interface,
}
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
continue
}
flushedCount++
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
}
if flushedCount > 0 {
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_ADD, prefix, nexthop) return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
} }
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{ msg = &route.RouteMessage{
Type: action, Type: action,
Flags: unix.RTF_UP, Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Seq: r.getSeq(), Seq: r.getSeq(),
} }

View File

@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
data, err := os.ReadFile(m.filePath) data, err := os.ReadFile(m.filePath)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist") log.Debugf("state file %s does not exist", m.filePath)
return nil, nil // nolint:nilnil return nil, nil // nolint:nilnil
} }
return nil, fmt.Errorf("read state file: %w", err) return nil, fmt.Errorf("read state file: %w", err)

View File

@@ -0,0 +1,59 @@
package winregistry
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows/registry"
)
var (
advapi = syscall.NewLazyDLL("advapi32.dll")
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
)
const (
// Registry key options
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
// Registry disposition values
regCreatedNewKey = 0x1
regOpenedExistingKey = 0x2
)
// CreateVolatileKey creates a volatile registry key named path under open key root.
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
// The access parameter specifies the access rights for the key to be created.
//
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
// This provides automatic cleanup without requiring manual registry maintenance.
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
pathPtr, err := syscall.UTF16PtrFromString(path)
if err != nil {
return 0, false, err
}
var (
handle syscall.Handle
disposition uint32
)
ret, _, _ := regCreateKeyExW.Call(
uintptr(root),
uintptr(unsafe.Pointer(pathPtr)),
0, // reserved
0, // class
uintptr(regOptionVolatile), // options - volatile key
uintptr(access), // desired access
0, // security attributes
uintptr(unsafe.Pointer(&handle)),
uintptr(unsafe.Pointer(&disposition)),
)
if ret != 0 {
return 0, false, syscall.Errno(ret)
}
return registry.Key(handle), disposition == regOpenedExistingKey, nil
}

View File

@@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.CustomDNSAddress = []byte{} config.CustomDNSAddress = []byte{}
} }
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
if msg.DnsRouteInterval != nil {
interval := msg.DnsRouteInterval.AsDuration()
config.DNSRouteInterval = &interval
}
config.RosenpassEnabled = msg.RosenpassEnabled config.RosenpassEnabled = msg.RosenpassEnabled
config.RosenpassPermissive = msg.RosenpassPermissive config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect config.DisableAutoConnect = msg.DisableAutoConnect
@@ -1050,10 +1057,7 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus { if msg.GetFullPeerStatus {
if msg.ShouldRunProbes { s.runProbes(msg.ShouldRunProbes)
s.runProbes()
}
fullStatus := s.statusRecorder.GetFullStatus() fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory() pbFullStatus.Events = s.statusRecorder.GetEventHistory()
@@ -1063,7 +1067,7 @@ func (s *Server) Status(
return &statusResponse, nil return &statusResponse, nil
} }
func (s *Server) runProbes() { func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil { if s.connectClient == nil {
return return
} }
@@ -1074,7 +1078,7 @@ func (s *Server) runProbes() {
} }
if time.Since(s.lastProbe) > probeThreshold { if time.Since(s.lastProbe) > probeThreshold {
if engine.RunHealthProbes() { if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now() s.lastProbe = time.Now()
} }
} }

View File

@@ -0,0 +1,298 @@
package server
import (
"context"
"os/user"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
)
// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
// This test uses reflection to detect when new fields are added but not handled in SetConfig.
func TestSetConfig_AllFieldsSaved(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origDefaultConfigPath := profilemanager.DefaultConfigPath
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.ConfigDirOverride = tempDir
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
profilemanager.DefaultConfigPath = origDefaultConfigPath
profilemanager.ConfigDirOverride = ""
})
currUser, err := user.Current()
require.NoError(t, err)
profName := "test-profile"
ic := profilemanager.ConfigInput{
ConfigPath: filepath.Join(tempDir, profName+".json"),
ManagementURL: "https://api.netbird.io:443",
}
_, err = profilemanager.UpdateOrCreateConfig(ic)
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
Username: currUser.Username,
})
require.NoError(t, err)
ctx := context.Background()
s := New(ctx, "console", "", false, false)
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
disableAutoConnect := true
networkMonitor := true
disableClientRoutes := true
disableServerRoutes := true
disableDNS := true
disableFirewall := true
blockLANAccess := true
disableNotifications := true
lazyConnectionEnabled := true
blockInbound := true
mtu := int64(1280)
req := &proto.SetConfigRequest{
ProfileName: profName,
Username: currUser.Username,
ManagementUrl: "https://new-api.netbird.io:443",
AdminURL: "https://new-admin.netbird.io",
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
DisableAutoConnect: &disableAutoConnect,
NetworkMonitor: &networkMonitor,
DisableClientRoutes: &disableClientRoutes,
DisableServerRoutes: &disableServerRoutes,
DisableDns: &disableDNS,
DisableFirewall: &disableFirewall,
BlockLanAccess: &blockLANAccess,
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
DnsLabels: []string{"label1", "label2"},
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
}
_, err = s.SetConfig(ctx, req)
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
Name: profName,
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()
require.NoError(t, err)
cfg, err := profilemanager.GetConfig(cfgPath)
require.NoError(t, err)
require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
require.NotNil(t, cfg.NetworkMonitor)
require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
require.Equal(t, disableDNS, cfg.DisableDNS)
require.Equal(t, disableFirewall, cfg.DisableFirewall)
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
require.NotNil(t, cfg.DisableNotifications)
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
require.Equal(t, blockInbound, cfg.BlockInbound)
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
// IFaceBlackList contains defaults + extras
require.Contains(t, cfg.IFaceBlackList, "eth1")
require.Contains(t, cfg.IFaceBlackList, "eth2")
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
require.Equal(t, uint16(mtu), cfg.MTU)
verifyAllFieldsCovered(t, req)
}
// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
// If a new field is added to SetConfigRequest, this function will fail the test,
// forcing the developer to update both the SetConfig handler and this test.
func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
t.Helper()
metadataFields := map[string]bool{
"state": true, // protobuf internal
"sizeCache": true, // protobuf internal
"unknownFields": true, // protobuf internal
"Username": true, // metadata
"ProfileName": true, // metadata
"CleanNATExternalIPs": true, // control flag for clearing
"CleanDNSLabels": true, // control flag for clearing
}
expectedFields := map[string]bool{
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
}
val := reflect.ValueOf(req).Elem()
typ := val.Type()
var unexpectedFields []string
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldName := field.Name
if metadataFields[fieldName] {
continue
}
if !expectedFields[fieldName] {
unexpectedFields = append(unexpectedFields, fieldName)
}
}
if len(unexpectedFields) > 0 {
t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
}
}
// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
func TestCLIFlags_MappedToSetConfig(t *testing.T) {
// Map of CLI flag names to their corresponding SetConfigRequest field names.
// This map must be updated when adding new config-related CLI flags.
flagToField := map[string]string{
"management-url": "ManagementUrl",
"admin-url": "AdminURL",
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",
"disable-auto-connect": "DisableAutoConnect",
"network-monitor": "NetworkMonitor",
"disable-client-routes": "DisableClientRoutes",
"disable-server-routes": "DisableServerRoutes",
"disable-dns": "DisableDns",
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
"extra-dns-labels": "DnsLabels",
"dns-router-interval": "DnsRouteInterval",
"mtu": "Mtu",
}
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
fieldsWithoutCLIFlags := map[string]bool{
"DisableNotifications": true, // Only settable via UI
}
// Get all SetConfigRequest fields to verify our map is complete.
req := &proto.SetConfigRequest{}
val := reflect.ValueOf(req).Elem()
typ := val.Type()
var unmappedFields []string
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldName := field.Name
// Skip protobuf internal fields and metadata fields.
if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" {
continue
}
if fieldName == "Username" || fieldName == "ProfileName" {
continue
}
if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" {
continue
}
// Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag.
mappedToCLI := false
for _, mappedField := range flagToField {
if mappedField == fieldName {
mappedToCLI = true
break
}
}
hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName]
if !mappedToCLI && !hasNoCLIFlag {
unmappedFields = append(unmappedFields, fieldName)
}
}
if len(unmappedFields) > 0 {
t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
"Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
"add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
}
t.Log("All SetConfigRequest fields are properly documented")
}

View File

@@ -10,7 +10,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
} }
// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@@ -205,15 +206,18 @@ func mapPeers(
localICEEndpoint := "" localICEEndpoint := ""
remoteICEEndpoint := "" remoteICEEndpoint := ""
relayServerAddress := "" relayServerAddress := ""
connType := "P2P" connType := "-"
lastHandshake := time.Time{} lastHandshake := time.Time{}
transferReceived := int64(0) transferReceived := int64(0)
transferSent := int64(0) transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if pbPeerState.Relayed { if isPeerConnected {
connType = "Relayed" connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
} }
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
@@ -337,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
for _, relay := range overview.Relays.Details { for _, relay := range overview.Relays.Details {
available := "Available" available := "Available"
reason := "" reason := ""
if !relay.Available { if !relay.Available {
available = "Unavailable" if relay.Error == probeRelay.ErrCheckInProgress.Error() {
reason = fmt.Sprintf(", reason: %s", relay.Error) available = "Checking..."
} else {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} }
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
} }
} else { } else {

View File

@@ -31,7 +31,6 @@ import (
"fyne.io/systray" "fyne.io/systray"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@@ -297,6 +296,8 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem
logFile string logFile string
wLoginURL fyne.Window wLoginURL fyne.Window
connectCancel context.CancelFunc
} }
type menuHandler struct { type menuHandler struct {
@@ -593,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
} }
} }
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) return nil, fmt.Errorf("get daemon client: %w", err)
return nil, err
} }
activeProf, err := s.profileManager.GetActiveProfile() activeProf, err := s.profileManager.GetActiveProfile()
if err != nil { if err != nil {
log.Errorf("get active profile: %v", err) return nil, fmt.Errorf("get active profile: %w", err)
return nil, err
} }
currUser, err := user.Current() currUser, err := user.Current()
@@ -611,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
return nil, fmt.Errorf("get current user: %w", err) return nil, fmt.Errorf("get current user: %w", err)
} }
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ loginResp, err := conn.Login(ctx, &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name, ProfileName: &activeProf.Name,
Username: &currUser.Username, Username: &currUser.Username,
}) })
if err != nil { if err != nil {
log.Errorf("login to management URL with: %v", err) return nil, fmt.Errorf("login to management: %w", err)
return nil, err
} }
if loginResp.NeedsSSOLogin && openURL { if loginResp.NeedsSSOLogin && openURL {
err = s.handleSSOLogin(loginResp, conn) if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
if err != nil { return nil, fmt.Errorf("SSO login: %w", err)
log.Errorf("handle SSO login failed: %v", err)
return nil, err
} }
} }
return loginResp, nil return loginResp, nil
} }
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
err := open.Run(loginResp.VerificationURIComplete) if err := openURL(loginResp.VerificationURIComplete); err != nil {
if err != nil { return fmt.Errorf("open browser: %w", err)
log.Errorf("opening the verification uri in the browser failed: %v", err)
return err
} }
resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil { if err != nil {
log.Errorf("waiting sso login failed with: %v", err) return fmt.Errorf("wait for SSO login: %w", err)
return err
} }
if resp.Email != "" { if resp.Email != "" {
err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
Email: resp.Email, Email: resp.Email,
}) }); err != nil {
if err != nil { log.Debugf("failed to set profile state: %v", err)
log.Warnf("failed to set profile state: %v", err)
} else { } else {
s.mProfile.refresh() s.mProfile.refresh()
} }
} }
return nil return nil
} }
func (s *serviceClient) menuUpClick() error { func (s *serviceClient) menuUpClick(ctx context.Context) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
systray.SetTemplateIcon(iconErrorMacOS, s.icError) systray.SetTemplateIcon(iconErrorMacOS, s.icError)
log.Errorf("get client: %v", err) return fmt.Errorf("get daemon client: %w", err)
return err
} }
_, err = s.login(true) _, err = s.login(ctx, true)
if err != nil { if err != nil {
log.Errorf("login failed with: %v", err) return fmt.Errorf("login: %w", err)
return err
} }
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) return fmt.Errorf("get status: %w", err)
return err
} }
if status.Status == string(internal.StatusConnected) { if status.Status == string(internal.StatusConnected) {
log.Warnf("already connected")
return nil return nil
} }
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
log.Errorf("up service: %v", err) return fmt.Errorf("start connection: %w", err)
return err
} }
return nil return nil
@@ -698,24 +684,20 @@ func (s *serviceClient) menuDownClick() error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) return fmt.Errorf("get daemon client: %w", err)
return err
} }
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) return fmt.Errorf("get status: %w", err)
return err
} }
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
log.Warnf("already down")
return nil return nil
} }
if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("down service: %v", err) return fmt.Errorf("stop connection: %w", err)
return err
} }
return nil return nil
@@ -1354,7 +1336,13 @@ func (s *serviceClient) updateConfig() error {
} }
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
func (s *serviceClient) showLoginURL() { // It also starts a background goroutine that periodically checks if the client is already connected
// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
// also cancelled when the window is closed.
func (s *serviceClient) showLoginURL() context.CancelFunc {
// create a cancellable context for the background check goroutine
ctx, cancel := context.WithCancel(s.ctx)
resIcon := fyne.NewStaticResource("netbird.png", iconAbout) resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
@@ -1363,6 +1351,8 @@ func (s *serviceClient) showLoginURL() {
s.wLoginURL.Resize(fyne.NewSize(400, 200)) s.wLoginURL.Resize(fyne.NewSize(400, 200))
s.wLoginURL.SetIcon(resIcon) s.wLoginURL.SetIcon(resIcon)
} }
// ensure goroutine is cancelled when the window is closed
s.wLoginURL.SetOnClosed(func() { cancel() })
// add a description label // add a description label
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
@@ -1374,7 +1364,7 @@ func (s *serviceClient) showLoginURL() {
return return
} }
resp, err := s.login(false) resp, err := s.login(ctx, false)
if err != nil { if err != nil {
log.Errorf("failed to fetch login URL: %v", err) log.Errorf("failed to fetch login URL: %v", err)
return return
@@ -1394,7 +1384,7 @@ func (s *serviceClient) showLoginURL() {
return return
} }
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
if err != nil { if err != nil {
log.Errorf("Waiting sso login failed with: %v", err) log.Errorf("Waiting sso login failed with: %v", err)
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
@@ -1402,7 +1392,7 @@ func (s *serviceClient) showLoginURL() {
} }
label.SetText("Re-authentication successful.\nReconnecting") label.SetText("Re-authentication successful.\nReconnecting")
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) log.Errorf("get service status: %v", err)
return return
@@ -1415,7 +1405,7 @@ func (s *serviceClient) showLoginURL() {
return return
} }
_, err = conn.Up(s.ctx, &proto.UpRequest{}) _, err = conn.Up(ctx, &proto.UpRequest{})
if err != nil { if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
log.Errorf("Reconnecting failed with: %v", err) log.Errorf("Reconnecting failed with: %v", err)
@@ -1443,10 +1433,46 @@ func (s *serviceClient) showLoginURL() {
) )
s.wLoginURL.SetContent(container.NewCenter(content)) s.wLoginURL.SetContent(container.NewCenter(content))
// start a goroutine to check connection status and close the window if connected
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
continue
}
if status.Status == string(internal.StatusConnected) {
if s.wLoginURL != nil {
s.wLoginURL.Close()
}
return
}
}
}
}()
s.wLoginURL.Show() s.wLoginURL.Show()
// return cancel func so callers can stop the background goroutine if desired
return cancel
} }
func openURL(url string) error { func openURL(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
var err error var err error
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":

View File

@@ -18,6 +18,7 @@ import (
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types" uptypes "github.com/netbirdio/netbird/upload-server/types"
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
return "", err return "", err
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
log.Warnf("Failed to get post-up status: %v", err) log.Warnf("Failed to get post-up status: %v", err)
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err) return nil, fmt.Errorf("get client: %v", err)
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err) log.Warnf("failed to get status for debug bundle: %v", err)
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }

View File

@@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/systray" "fyne.io/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) {
func (h *eventHandler) handleConnectClick() { func (h *eventHandler) handleConnectClick() {
h.client.mUp.Disable() h.client.mUp.Disable()
if h.client.connectCancel != nil {
h.client.connectCancel()
}
connectCtx, connectCancel := context.WithCancel(h.client.ctx)
h.client.connectCancel = connectCancel
go func() { go func() {
defer h.client.mUp.Enable() defer connectCancel()
if err := h.client.menuUpClick(); err != nil {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) if err := h.client.menuUpClick(connectCtx); err != nil {
st, ok := status.FromError(err)
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user")
} else {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
log.Errorf("connect failed: %v", err)
}
}
if err := h.client.updateStatus(); err != nil {
log.Debugf("failed to update status after connect: %v", err)
} }
}() }()
} }
func (h *eventHandler) handleDisconnectClick() { func (h *eventHandler) handleDisconnectClick() {
h.client.mDown.Disable() h.client.mDown.Disable()
if h.client.connectCancel != nil {
log.Debugf("cancelling ongoing connect operation")
h.client.connectCancel()
h.client.connectCancel = nil
}
go func() { go func() {
defer h.client.mDown.Enable()
if err := h.client.menuDownClick(); err != nil { if err := h.client.menuDownClick(); err != nil {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) st, ok := status.FromError(err)
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
log.Errorf("disconnect failed: %v", err)
} else {
log.Debugf("disconnect cancelled or already disconnecting")
}
}
if err := h.client.updateStatus(); err != nil {
log.Debugf("failed to update status after disconnect: %v", err)
} }
}() }()
} }
@@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
} }
h.client.getSrvConfig() h.client.getSrvConfig()
return nil return nil
} }

View File

@@ -387,6 +387,7 @@ type subItem struct {
type profileMenu struct { type profileMenu struct {
mu sync.Mutex mu sync.Mutex
ctx context.Context ctx context.Context
serviceClient *serviceClient
profileManager *profilemanager.ProfileManager profileManager *profilemanager.ProfileManager
eventHandler *eventHandler eventHandler *eventHandler
profileMenuItem *systray.MenuItem profileMenuItem *systray.MenuItem
@@ -396,7 +397,7 @@ type profileMenu struct {
logoutSubItem *subItem logoutSubItem *subItem
profilesState []Profile profilesState []Profile
downClickCallback func() error downClickCallback func() error
upClickCallback func() error upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func() loadSettingsCallback func()
app fyne.App app fyne.App
@@ -404,12 +405,13 @@ type profileMenu struct {
type newProfileMenuArgs struct { type newProfileMenuArgs struct {
ctx context.Context ctx context.Context
serviceClient *serviceClient
profileManager *profilemanager.ProfileManager profileManager *profilemanager.ProfileManager
eventHandler *eventHandler eventHandler *eventHandler
profileMenuItem *systray.MenuItem profileMenuItem *systray.MenuItem
emailMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem
downClickCallback func() error downClickCallback func() error
upClickCallback func() error upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func() loadSettingsCallback func()
app fyne.App app fyne.App
@@ -418,6 +420,7 @@ type newProfileMenuArgs struct {
func newProfileMenu(args newProfileMenuArgs) *profileMenu { func newProfileMenu(args newProfileMenuArgs) *profileMenu {
p := profileMenu{ p := profileMenu{
ctx: args.ctx, ctx: args.ctx,
serviceClient: args.serviceClient,
profileManager: args.profileManager, profileManager: args.profileManager,
eventHandler: args.eventHandler, eventHandler: args.eventHandler,
profileMenuItem: args.profileMenuItem, profileMenuItem: args.profileMenuItem,
@@ -569,10 +572,19 @@ func (p *profileMenu) refresh() {
} }
} }
if err := p.upClickCallback(); err != nil { if p.serviceClient.connectCancel != nil {
p.serviceClient.connectCancel()
}
connectCtx, connectCancel := context.WithCancel(p.ctx)
p.serviceClient.connectCancel = connectCancel
if err := p.upClickCallback(connectCtx); err != nil {
log.Errorf("failed to handle up click after switching profile: %v", err) log.Errorf("failed to handle up click after switching profile: %v", err)
} }
connectCancel()
p.refresh() p.refresh()
p.loadSettingsCallback() p.loadSettingsCallback()
} }

View File

@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
} }
} }
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
return &tls.Config{ config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error { VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte var certChain [][]byte
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
return nil return nil
}, },
} }
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
} }

View File

@@ -6,11 +6,13 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/asn1" "encoding/asn1"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"sync" "sync"
"syscall/js" "syscall/js"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -19,18 +21,34 @@ const (
RDCleanPathVersion = 3390 RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local" RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws" RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
) )
type RDCleanPathPDU struct { type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"` Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"` Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"` Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathErr struct {
ErrorCode int16 `asn1:"tag:0,explicit"`
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
WSALastError int16 `asn1:"tag:2,explicit,optional"`
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
} }
type RDCleanPathProxy struct { type RDCleanPathProxy struct {
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
destination := conn.destination destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil { if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err) log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return return
} }
conn.rdpConn = rdpConn conn.rdpConn = rdpConn
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
_, err = rdpConn.Write(firstPacket) _, err = rdpConn.Write(firstPacket)
if err != nil { if err != nil {
log.Errorf("Failed to write first packet: %v", err) log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return return
} }
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
n, err := rdpConn.Read(response) n, err := rdpConn.Read(response)
if err != nil { if err != nil {
log.Errorf("Failed to read X.224 response: %v", err) log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return return
} }
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
conn.wsHandlers.Call("send", uint8Array.Get("buffer")) conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
} }
} }
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func errorToWSACode(err error) int16 {
if err == nil {
return WSAEGenericError
}
var netErr *net.OpError
if errors.As(err, &netErr) && netErr.Timeout() {
return WSAETimedOut
}
if errors.Is(err, context.DeadlineExceeded) {
return WSAETimedOut
}
if errors.Is(err, context.Canceled) {
return WSAEConnAborted
}
if errors.Is(err, io.EOF) {
return WSAEConnReset
}
return WSAEGenericError
}
func newWSAError(err error) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
WSALastError: errorToWSACode(err),
},
}
}
func newHTTPError(statusCode int16) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
HTTPStatusCode: statusCode,
},
}
}

View File

@@ -3,6 +3,7 @@
package rdp package rdp
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/asn1" "encoding/asn1"
"io" "io"
@@ -11,11 +12,17 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion { if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, "Unsupported version") p.sendRDCleanPathError(conn, newHTTPError(400))
return return
} }
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
destination = pdu.Destination destination = pdu.Destination
} }
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil { if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err) log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, "Connection failed") p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn) p.cleanupConnection(conn)
return return
} }
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
p.setupTLSConnection(conn, pdu) p.setupTLSConnection(conn, pdu)
} }
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
const minResponseLength = 19
if len(x224Response) < minResponseLength {
return false, 0, false
}
// Per X.224 specification:
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
// x224Response[5] == 0xD0: X.224 Data TPDU code
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
return false, 0, false
}
if x224Response[11] == 0x02 {
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
return hasNLA, flags, true
}
return false, 0, false
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 { if len(pdu.X224ConnectionPDU) > 0 {
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil { if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err) log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224") p.sendRDCleanPathError(conn, newWSAError(err))
return return
} }
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
n, err := conn.rdpConn.Read(response) n, err := conn.rdpConn.Read(response)
if err != nil { if err != nil {
log.Errorf("Failed to read X.224 response: %v", err) log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response") p.sendRDCleanPathError(conn, newWSAError(err))
return return
} }
x224Response = response[:n] x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
} }
tlsConfig := p.getTLSConfigWithValidation(conn) requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig) tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err) log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, "TLS handshake failed") p.sendRDCleanPathError(conn, newWSAError(err))
return return
} }
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
p.cleanupConnection(conn) p.cleanupConnection(conn)
} }
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu) data, err := asn1.Marshal(pdu)
if err != nil { if err != nil {
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
p.sendToWebSocket(conn, data) p.sendToWebSocket(conn, data)
} }
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte) msgChan := make(chan []byte)
errChan := make(chan error) errChan := make(chan error)

Some files were not shown because too many files have changed in this diff Show More