Compare commits

...

31 Commits

Author SHA1 Message Date
Viktor Liu
b2d7121695 Add logging around routes 2025-10-28 14:20:15 +01:00
Viktor Liu
3288c4414f Merge branch 'shutdown-block' into feature/detect-mac-wakeup 2025-10-27 23:20:25 +01:00
Viktor Liu
f3b0439211 Merge branch 'main' into feature/detect-mac-wakeup 2025-10-27 23:20:19 +01:00
Viktor Liu
ae801d77fb Block on all subsystems on shutdown 2025-10-27 23:15:47 +01:00
Pascal Fischer
4545ab9a52 [management] rewire account manager to permissions manager (#4673) 2025-10-27 22:59:35 +01:00
Bethuel Mmbaga
7f08983207 Include expired and routing peers in DNS record filtering (#4708) 2025-10-27 22:16:17 +03:00
Viktor Liu
eddea14521 [client] Clean up bsd routes independently of the state file (#4688) 2025-10-27 18:54:00 +01:00
Viktor Liu
b9ef214ea5 [client] Fix macOS state-based dns cleanup (#4701) 2025-10-27 18:35:32 +01:00
Bethuel Mmbaga
709e24eb6f [signal] Fix HTTP/WebSocket proxy not using custom certificates (#4644)
This pull request fixes a bug where the HTTP/WebSocket proxy server was not using custom TLS certificates when provided via --cert-file and --cert-key flags. Previously, only the gRPC server had TLS enabled with custom certificates, while the HTTP/WebSocket proxy ran without TLS.
2025-10-24 15:40:20 +03:00
Viktor Liu
bf83549db2 Merge branch 'main' into feature/detect-mac-wakeup 2025-10-23 17:09:06 +02:00
Viktor Liu
804a3871fe Merge branch 'fix-deprecated-grpc' into feature/detect-mac-wakeup 2025-10-23 17:08:12 +02:00
Viktor Liu
6654e2dbf7 [client] Fix active profile name in debug bundle (#4689) 2025-10-23 17:07:52 +02:00
Viktor Liu
64d1edce27 Merge branch 'bsd-route-cleanup' into feature/detect-mac-wakeup 2025-10-23 17:07:22 +02:00
Viktor Liu
bf0698e5aa Clean up bsd routes independently of the state file 2025-10-23 16:42:23 +02:00
Viktor Liu
fc15625963 Clean up failed conn hooks 2025-10-23 15:35:45 +02:00
Viktor Liu
a75dde33b9 Fix happy eyeballs for grpc 2025-10-23 13:14:37 +02:00
Bethuel Mmbaga
d80d47a469 [management] Add peer disapproval reason (#4468) 2025-10-22 12:46:22 +03:00
Maycon Santos
96f71ff1e1 [misc] Update tag name extraction in install.sh (#4677) 2025-10-21 19:23:11 +02:00
Viktor Liu
2fe2af38d2 [client] Clean up match domain reg entries between config changes (#4676) 2025-10-21 18:14:39 +02:00
Viktor Liu
bb46e438aa Add gw change polling and time drift detection (informational only) 2025-10-21 12:19:35 +02:00
Misha Bragin
cd9a867ad0 [client] Delete TURNConfig section from script (#4639) 2025-10-17 19:48:26 +02:00
Maycon Santos
0f9bfeff7c [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 2025-10-17 19:47:11 +02:00
Viktor Liu
f5301230bf [client] Fix status showing P2P without connection (#4661) 2025-10-17 13:31:15 +02:00
Viktor Liu
429d7d6585 [client] Support BROWSER env for login (#4654) 2025-10-17 11:10:16 +02:00
Viktor Liu
3cdb10cde7 [client] Remove rule squashing (#4653) 2025-10-17 11:09:39 +02:00
Zoltán Papp
11ba253ffb Merge branch 'main' into feature/detect-mac-wakeup 2025-10-16 17:16:19 +02:00
Zoltan Papp
14fe7c29cb Change log levels 2025-10-14 18:13:52 +02:00
Zoltan Papp
158f3aceff Handle better sleep period 2025-10-14 12:24:39 +02:00
Zoltan Papp
bfa776c155 Update client/internal/networkmonitor/check_change_darwin.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-14 12:18:00 +02:00
Zoltan Papp
885b5c68ad Update client/internal/networkmonitor/monitor.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-14 12:17:51 +02:00
Zoltan Papp
b1ebac795d Extend Darwin network monitoring with wakeup detection 2025-10-14 12:14:08 +02:00
61 changed files with 1220 additions and 838 deletions

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
@@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil {
if err := openBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"

View File

@@ -4,12 +4,15 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"runtime"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -17,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}))
}
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
conn, err := grpc.NewClient(
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
}

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil

View File

@@ -29,11 +29,6 @@ type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
}
type protoMatch struct {
ips map[string]int
policyID []byte
}
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap)
rules := networkMap.FirewallRules
enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
// if TCP protocol rules not squashed and SSH enabled
// we add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP 22 port).
// If SSH enabled, add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic.
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
totalIPs++
}
}
in := map[mgmProto.RuleProtocol]*protoMatch{}
out := map[mgmProto.RuleProtocol]*protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if:
// 1. Any of the rule has DROP action type.
// 2. Any of rule contains Port.
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
}
// special case, when we receive this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return
}
ipset := protocols[r.Protocol].ips
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
}
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
}
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
protocolOrders := []mgmProto.RuleProtocol{
mgmProto.RuleProtocol_ALL,
mgmProto.RuleProtocol_ICMP,
mgmProto.RuleProtocol_TCP,
mgmProto.RuleProtocol_UDP,
}
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
// don't squash if :
// 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network.
continue
}
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
PolicyID: match.policyID,
})
squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
}
}
}
squash(in, mgmProto.RuleDirection_IN)
squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
if len(squashedRules) == 0 {
return networkMap.FirewallRules, squashedProtocols
}
var rules []*mgmProto.FirewallRule
// filter out rules which was squashed from final list
// if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
}
}
rules = append(rules, r)
}
return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)

View File

@@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
func TestDefaultManagerSquashRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules))
r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
}
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
<-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()
backOff.Reset()
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
}
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
if c.engine == nil {
return nil
}
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
}

View File

@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states.
state.json: Anonymized client state dump containing netbird states for the active profile.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information.
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
log.Debugf("Adding state file from: %s", path)
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {

View File

@@ -13,6 +13,7 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
}
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var (
searchDomains []string
matchDomains []string
)
err = s.recordSystemDNSSettings(true)
if err != nil {
if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
if err := s.addLocalDNS(); err != nil {
log.Warnf("failed to add local DNS: %v", err)
}
s.updateState(stateManager)
}
for _, dConf := range config.Domains {
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else {
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil {
return fmt.Errorf("add match domains: %w", err)
}
s.updateState(stateManager)
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil {
return fmt.Errorf("add search domains: %w", err)
}
s.updateState(stateManager)
if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return nil
}
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (s *systemConfigurator) string() string {
return "scutil"
}
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration")
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
}
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
log.Info("Not enabling local DNS server")
return nil
}
if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
return nil

View File

@@ -0,0 +1,111 @@
//go:build !ios
package dns
import (
"context"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
defer func() {
require.NoError(t, sm.Stop(context.Background()))
}()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
config := HostDNSConfig{
ServerIP: netip.MustParseAddr("100.64.0.1"),
ServerPort: 53,
RouteAll: true,
Domains: []DomainConfig{
{Domain: "example.com", MatchOnly: true},
},
}
err := configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
defer func() {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}()
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
t.Logf("Key %s exists before cleanup", key)
}
}
sm2 := statemanager.New(stateFile)
sm2.RegisterState(&ShutdownState{})
err = sm2.LoadState(&ShutdownState{})
require.NoError(t, err)
state := sm2.GetState(&ShutdownState{})
if state == nil {
t.Skip("State not saved, skipping cleanup test")
}
shutdownState, ok := state.(*ShutdownState)
require.True(t, ok)
err = shutdownState.Cleanup()
require.NoError(t, err)
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}
func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
output, err := cmd.CombinedOutput()
if err != nil {
if strings.Contains(string(output), "No such key") {
return false, nil
}
return false, err
}
return !strings.Contains(string(output), "No such key"), nil
}
func removeTestDNSKey(key string) error {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
_, err := cmd.CombinedOutput()
return err
}

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry"
)
var (
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
r.updateState(stateManager)
var searchDomains, matchDomains []string
for _, dConf := range config.Domains {
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
}
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("cleanup old dns match policies: %s", err)
}
if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
}
r.nrptEntryCount = count
} else {
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
r.nrptEntryCount = 0
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
r.updateState(stateManager)
if err := r.updateSearchDomains(searchDomains); err != nil {
return fmt.Errorf("update search domains: %w", err)
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil
}
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err)
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return fmt.Errorf("remove existing dns policy: %w", err)
}
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
}
defer closer(regKey)

View File

@@ -0,0 +1,102 @@
package dns
import (
"fmt"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry"
)
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
}
k.Close()
return true, nil
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}

View File

@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
shutdownWg sync.WaitGroup
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
// This is different from ServiceEnable=false from management which completely disables the DNS service.
disableSys bool
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server
func (s *DefaultServer) Stop() {
s.ctxCancel()
s.shutdownWg.Wait()
s.mux.Lock()
defer s.mux.Unlock()
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.applyHostConfig()
s.shutdownWg.Add(1)
go func() {
// persist dns state right away
defer s.shutdownWg.Done()
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}

View File

@@ -7,6 +7,7 @@ import (
)
type ShutdownState struct {
CreatedKeys []string
}
func (s *ShutdownState) Name() string {
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create host manager: %w", err)
}
for _, key := range s.CreatedKeys {
manager.createdKeys[key] = struct{}{}
}
if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err)
}

View File

@@ -200,8 +200,10 @@ type Engine struct {
flowManager nftypes.FlowManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
wgIfaceMonitor *WGIfaceMonitor
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
shutdownWg sync.WaitGroup
// dns forwarder port
dnsFwdPort uint16
@@ -326,10 +328,6 @@ func (e *Engine) Stop() error {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
// stop flow manager after wg interface is gone
@@ -337,8 +335,6 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
@@ -349,12 +345,52 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
func (e *Engine) calculateShutdownTimeout() time.Duration {
peerCount := len(e.peerStore.PeersPubKey())
baseTimeout := 10 * time.Second
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
timeout := baseTimeout + perPeerTimeout
maxTimeout := 30 * time.Second
if timeout > maxTimeout {
timeout = maxTimeout
}
return timeout
}
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
@@ -484,14 +520,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
e.shutdownWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
e.triggerClientRestart()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
@@ -892,7 +928,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// blocking
err = e.sshServer.Start()
if err != nil {
@@ -950,7 +988,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
@@ -1368,7 +1408,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
@@ -1724,8 +1766,10 @@ func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
// triggerClientRestart triggers a full client restart by cancelling the client context.
// Note: This does NOT just restart the engine - it cancels the entire client context,
// which causes the connect client's retry loop to create a completely new engine.
func (e *Engine) triggerClientRestart() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1747,7 +1791,9 @@ func (e *Engine) startNetworkMonitor() {
}
e.networkMonitor = networkmonitor.New()
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
@@ -1757,8 +1803,8 @@ func (e *Engine) startNetworkMonitor() {
return
}
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
log.Infof("Network monitor: detected network change, triggering client restart")
e.triggerClientRestart()
}()
}

View File

@@ -24,6 +24,7 @@ import (
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
go m.receiveACKs(ctx, flowClient)
go m.startSender(ctx)
m.shutdownWg.Add(2)
go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient)
}()
go func() {
defer m.shutdownWg.Done()
m.startSender(ctx)
}()
return nil
}
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
// Close cleans up all resources
func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if err := m.disableFlow(); err != nil {
log.Warnf("failed to disable flow manager: %v", err)
}
m.mux.Unlock()
m.shutdownWg.Wait()
}
// GetLogger returns the flow logger

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
//go:build dragonfly || freebsd || netbsd || openbsd
package networkmonitor

View File

@@ -0,0 +1,344 @@
//go:build darwin && !ios
package networkmonitor
import (
"context"
"errors"
"fmt"
"hash/fnv"
"net/netip"
"os/exec"
"syscall"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// todo: refactor to not use static functions
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %v", err)
}
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}()
routeChanged := make(chan struct{})
go func() {
routeCheck(ctx, fd, nexthopv4, nexthopv6)
close(routeChanged)
}()
wakeUp := make(chan struct{})
go func() {
wakeUpListen(ctx)
close(wakeUp)
}()
gatewayChanged := make(chan string)
go func() {
gatewayPoll(ctx, nexthopv4, nexthopv6, gatewayChanged)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-routeChanged:
log.Infof("route change detected via routing socket")
return nil
case <-wakeUp:
log.Infof("wakeup detected via sleep hash change")
return nil
case reason := <-gatewayChanged:
log.Infof("gateway change detected via polling: %s", reason)
return nil
}
}
func routeCheck(ctx context.Context, fd int, nexthopv4 systemops.Nexthop, nexthopv6 systemops.Nexthop) {
for {
if ctx.Err() != nil {
return
}
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
}
func wakeUpListen(ctx context.Context) {
log.Infof("start to watch for system wakeups")
var (
initialHash uint32
err error
)
// Keep retrying until initial sysctl succeeds or context is canceled
for {
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
default:
initialHash, err = readSleepTimeHash()
if err != nil {
log.Errorf("failed to detect initial sleep time: %v", err)
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
case <-time.After(3 * time.Second):
continue
}
}
log.Infof("initial wakeup hash: %d", initialHash)
break
}
break
}
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
lastCheck := time.Now()
const maxTickerDrift = 1 * time.Minute
for {
select {
case <-ctx.Done():
log.Info("context canceled, stopping wakeUpListen")
return
case <-ticker.C:
now := time.Now()
elapsed := now.Sub(lastCheck)
// If more time passed than expected, system likely slept (informational only)
if elapsed > maxTickerDrift {
upOut, err := exec.Command("uptime").Output()
if err != nil {
log.Errorf("failed to run uptime command: %v", err)
upOut = []byte("unknown")
}
log.Infof("Time drift detected (potential wakeup): expected ~5s, actual %s, uptime: %s", elapsed, upOut)
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
if errV4 == nil {
log.Infof("Current IPv4 default gateway: %s via %s", currentV4.IP, currentV4.Intf.Name)
} else {
log.Debugf("No IPv4 default gateway: %v", errV4)
}
if errV6 == nil {
log.Infof("Current IPv6 default gateway: %s via %s", currentV6.IP, currentV6.Intf.Name)
} else {
log.Debugf("No IPv6 default gateway: %v", errV6)
}
}
newHash, err := readSleepTimeHash()
if err != nil {
log.Errorf("failed to read sleep time hash: %v", err)
lastCheck = now
continue
}
if newHash == initialHash {
log.Debugf("no wakeup detected (hash unchanged: %d, time drift: %s)", initialHash, elapsed)
lastCheck = now
continue
}
upOut, err := exec.Command("uptime").Output()
if err != nil {
log.Errorf("failed to run uptime command: %v", err)
upOut = []byte("unknown")
}
log.Infof("Wakeup detected via hash change: %d -> %d, uptime: %s", initialHash, newHash, upOut)
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
if errV4 == nil {
log.Infof("Current IPv4 default gateway after wakeup: %s via %s", currentV4.IP, currentV4.Intf.Name)
} else {
log.Debugf("No IPv4 default gateway after wakeup: %v", errV4)
}
if errV6 == nil {
log.Infof("Current IPv6 default gateway after wakeup: %s via %s", currentV6.IP, currentV6.Intf.Name)
} else {
log.Debugf("No IPv6 default gateway after wakeup: %v", errV6)
}
return
}
}
}
func readSleepTimeHash() (uint32, error) {
cmd := exec.Command("sysctl", "kern.sleeptime")
out, err := cmd.Output()
if err != nil {
return 0, fmt.Errorf("failed to run sysctl: %w", err)
}
h, err := hash(out)
if err != nil {
return 0, fmt.Errorf("failed to compute hash: %w", err)
}
return h, nil
}
func hash(data []byte) (uint32, error) {
hasher := fnv.New32a()
if _, err := hasher.Write(data); err != nil {
return 0, err
}
return hasher.Sum32(), nil
}
// gatewayPoll polls the default gateway every 5 seconds to detect changes that might be missed by routing socket or wake-up detection.
func gatewayPoll(ctx context.Context, initialV4, initialV6 systemops.Nexthop, changed chan<- string) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
log.Infof("Gateway polling started - initial v4: %s via %v, v6: %s via %v",
initialV4.IP, initialV4.Intf, initialV6.IP, initialV6.Intf)
for {
select {
case <-ctx.Done():
log.Debug("context canceled, stopping gateway polling")
return
case <-ticker.C:
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
var reason string
if errV4 == nil && initialV4.IP.IsValid() {
if currentV4.IP.Compare(initialV4.IP) != 0 {
reason = fmt.Sprintf("IPv4 gateway changed from %s to %s", initialV4.IP, currentV4.IP)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
if initialV4.Intf != nil && currentV4.Intf != nil && currentV4.Intf.Name != initialV4.Intf.Name {
reason = fmt.Sprintf("IPv4 interface changed from %s to %s", initialV4.Intf.Name, currentV4.Intf.Name)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
} else if errV4 == nil && !initialV4.IP.IsValid() {
reason = "IPv4 gateway appeared"
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV4.IP)
changed <- reason
return
} else if errV4 != nil && initialV4.IP.IsValid() {
reason = "IPv4 gateway disappeared"
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
if errV6 == nil && initialV6.IP.IsValid() {
if currentV6.IP.Compare(initialV6.IP) != 0 {
reason = fmt.Sprintf("IPv6 gateway changed from %s to %s", initialV6.IP, currentV6.IP)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
if initialV6.Intf != nil && currentV6.Intf != nil && currentV6.Intf.Name != initialV6.Intf.Name {
reason = fmt.Sprintf("IPv6 interface changed from %s to %s", initialV6.Intf.Name, currentV6.Intf.Name)
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
} else if errV6 == nil && !initialV6.IP.IsValid() {
reason = "IPv6 gateway appeared"
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV6.IP)
changed <- reason
return
} else if errV6 != nil && initialV6.IP.IsValid() {
reason = "IPv6 gateway disappeared"
log.Infof("Gateway poll detected change: %s", reason)
changed <- reason
return
}
log.Debugf("Gateway poll: no change detected")
}
}
}

View File

@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
log.Infof("start watching for network changes")
// debounce changes
timer := time.NewTimer(0)
timer.Stop()

View File

@@ -19,11 +19,11 @@ type SRWatcher struct {
signalClient chNotifier
relayManager chNotifier
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
listeners map[chan struct{}]struct{}
mu sync.Mutex
shutdownWg sync.WaitGroup
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
cancelIceMonitor context.CancelFunc
}
@@ -52,7 +52,11 @@ func (w *SRWatcher) Start() {
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
w.shutdownWg.Add(1)
go func() {
defer w.shutdownWg.Done()
iceMonitor.Start(ctx, w.onICEChanged)
}()
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)
@@ -60,14 +64,16 @@ func (w *SRWatcher) Start() {
func (w *SRWatcher) Close() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor == nil {
w.mu.Unlock()
return
}
w.cancelIceMonitor()
w.signalClient.SetOnReconnectedListener(nil)
w.relayManager.SetOnReconnectedListener(nil)
w.mu.Unlock()
w.shutdownWg.Wait()
}
func (w *SRWatcher) NewListener() chan struct{} {

View File

@@ -78,6 +78,7 @@ type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
shutdownWg sync.WaitGroup
clientNetworks map[route.HAUniqueID]*client.Watcher
routeSelector *routeselector.RouteSelector
serverRouter *server.Router
@@ -106,7 +107,7 @@ type DefaultManager struct {
func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
sysOps := systemops.New(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
@@ -273,6 +274,7 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop()
m.shutdownWg.Wait()
if m.serverRouter != nil {
m.serverRouter.CleanUp()
}
@@ -474,7 +476,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
}
clientNetworkWatcher := client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
}
@@ -516,7 +522,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
}
clientNetworkWatcher = client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
}
update := client.RoutesUpdate{
UpdateSerial: updateSerial,

View File

@@ -0,0 +1,8 @@
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
package systemops
// FlushMarkedRoutes is a no-op on non-BSD platforms.
func (r *SysOps) FlushMarkedRoutes() error {
return nil
}

View File

@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s))
sysOps := New(nil, nil)
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
sysOps.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush()
return sysOps.refCounter.Flush()
}
func (s *ShutdownState) MarshalJSON() ([]byte, error) {

View File

@@ -83,7 +83,7 @@ type SysOps struct {
localSubnetsCacheTime time.Time
}
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,

View File

@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil)
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil)
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})
r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err")

View File

@@ -7,19 +7,39 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
"syscall"
"time"
"unsafe"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
)
var routeProtoFlag int
func init() {
switch os.Getenv(envRouteProtoFlag) {
case "2":
routeProtoFlag = unix.RTF_PROTO2
case "3":
routeProtoFlag = unix.RTF_PROTO3
default:
routeProtoFlag = unix.RTF_PROTO1
}
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager)
}
@@ -28,12 +48,88 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
return r.cleanupRefCounter(stateManager)
}
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
func (r *SysOps) FlushMarkedRoutes() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("Skipping route flush: %v", err)
continue
}
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
continue
}
nexthop := Nexthop{
IP: routeInfo.Gw,
Intf: routeInfo.Interface,
}
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
continue
}
flushedCount++
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
}
if flushedCount > 0 {
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
if prefix.IsSingleIP() {
log.Debugf("Adding single IP route: %s via %s", prefix, formatNexthop(nexthop))
}
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
if prefix.IsSingleIP() {
log.Debugf("Removing single IP route: %s via %s", prefix, formatNexthop(nexthop))
}
if err := r.routeSocket(unix.RTM_DELETE, prefix, nexthop); err != nil {
return err
}
if prefix.IsSingleIP() {
log.Debugf("Route removal completed for %s, verifying...", prefix)
if exists := r.verifyRouteRemoved(prefix); exists {
log.Warnf("Route %s still exists in routing table after removal", prefix)
} else {
log.Debugf("Verified route %s successfully removed", prefix)
}
}
return nil
}
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
@@ -105,7 +201,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
}
@@ -200,3 +296,51 @@ func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
}
// formatNexthop returns a string representation of the nexthop for logging.
func formatNexthop(nexthop Nexthop) string {
if nexthop.IP.IsValid() {
return nexthop.IP.String()
}
if nexthop.Intf != nil {
return nexthop.Intf.Name
}
return "direct"
}
// verifyRouteRemoved checks if a route still exists in the routing table.
func (r *SysOps) verifyRouteRemoved(prefix netip.Prefix) bool {
rib, err := retryFetchRIB()
if err != nil {
log.Debugf("Failed to fetch RIB for route verification: %v", err)
return false
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
log.Debugf("Failed to parse RIB for route verification: %v", err)
return false
}
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
continue
}
if routeInfo.Dst == prefix {
return true
}
}
return false
}

View File

@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist")
log.Debugf("state file %s does not exist", m.filePath)
return nil, nil // nolint:nilnil
}
return nil, fmt.Errorf("read state file: %w", err)

View File

@@ -0,0 +1,59 @@
package winregistry
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows/registry"
)
var (
advapi = syscall.NewLazyDLL("advapi32.dll")
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
)
const (
// Registry key options
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
// Registry disposition values
regCreatedNewKey = 0x1
regOpenedExistingKey = 0x2
)
// CreateVolatileKey creates a volatile registry key named path under open key root.
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
// The access parameter specifies the access rights for the key to be created.
//
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
// This provides automatic cleanup without requiring manual registry maintenance.
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
pathPtr, err := syscall.UTF16PtrFromString(path)
if err != nil {
return 0, false, err
}
var (
handle syscall.Handle
disposition uint32
)
ret, _, _ := regCreateKeyExW.Call(
uintptr(root),
uintptr(unsafe.Pointer(pathPtr)),
0, // reserved
0, // class
uintptr(regOptionVolatile), // options - volatile key
uintptr(access), // desired access
0, // security attributes
uintptr(unsafe.Pointer(&handle)),
uintptr(unsafe.Pointer(&disposition)),
)
if ret != 0 {
return 0, false, syscall.Errno(ret)
}
return registry.Key(handle), disposition == regOpenedExistingKey, nil
}

View File

@@ -17,8 +17,7 @@ type Conn struct {
ID hooks.ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn)
}
@@ -29,7 +28,7 @@ type TCPConn struct {
ID hooks.ConnectionID
}
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn)
}
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
// closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close()
cleanupConnID(id)
return err
}
// cleanupConnID executes close hooks for a connection ID.
func cleanupConnID(id hooks.ConnectionID) {
closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks {
if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err)
}
}
return err
}

View File

@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
}
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}

View File

@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
cleanupConnID(connID)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
return fmt.Errorf("resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)

View File

@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn)
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn)

View File

@@ -10,7 +10,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}
// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -205,15 +205,18 @@ func mapPeers(
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := "P2P"
connType := "-"
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if pbPeerState.Relayed {
connType = "Relayed"
if isPeerConnected {
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
}
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {

View File

@@ -31,7 +31,6 @@ import (
"fyne.io/systray"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
}
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
err := open.Run(loginResp.VerificationURIComplete)
err := openURL(loginResp.VerificationURIComplete)
if err != nil {
log.Errorf("opening the verification uri in the browser failed: %v", err)
return err
@@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
}
func openURL(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
var err error
switch runtime.GOOS {
case "windows":

View File

@@ -18,6 +18,7 @@ import (
"github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types"
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
return "", err
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err)
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err)
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}

2
go.mod
View File

@@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
echo ""
export NETBIRD_SIGNAL_PROTOCOL="https"
unset NETBIRD_LETSENCRYPT_DOMAIN
unset NETBIRD_MGMT_API_CERT_FILE
unset NETBIRD_MGMT_API_CERT_KEY_FILE
fi
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
export NETBIRD_SIGNAL_PROTOCOL="https"
fi
# Check if management identity provider is set
if [ -n "$NETBIRD_MGMT_IDP" ]; then
EXTRA_CONFIG={}

View File

@@ -40,13 +40,21 @@ services:
signal:
<<: *default
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
depends_on:
- dashboard
volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
ports:
- $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
command: [
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
"--log-file", "console"
]
# Relay
relay:

View File

@@ -682,17 +682,6 @@ renderManagementJson() {
"URI": "stun:$NETBIRD_DOMAIN:3478"
}
],
"TURNConfig": {
"Turns": [
{
"Proto": "udp",
"URI": "turn:$NETBIRD_DOMAIN:3478",
"Username": "$TURN_USER",
"Password": "$TURN_PASSWORD"
}
],
"TimeBasedCredentials": false
},
"Relay": {
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
"CredentialsTTL": "24h",

View File

@@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
return integrations.InitPermissionsManager(s.Store())
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
s.AfterInit(func(s *BaseServer) {
manager.SetAccountManager(s.AccountManager())
})
return manager
})
}

View File

@@ -109,7 +109,7 @@ type Manager interface {
GetIdpManager() idp.Manager
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error

View File

@@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
@@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
}
_, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid))
reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
_, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
h.setApprovalRequiredFlag(respBody, validPeersMap)
h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
util.WriteJSONObject(r.Context(), w, respBody)
}
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id]
_, ok := validPeersMap[peer.Id]
if !ok {
peer.ApprovalRequired = true
reason := invalidPeersMap[peer.Id]
peer.DisapprovalReason = &reason
}
}
}
@@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
}
}
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
}
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
}
return &api.Peer{
apiPeer := &api.Peer{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
@@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
}
if !approved {
apiPeer.DisapprovalReason = &reason
}
return apiPeer
}
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {

View File

@@ -7,9 +7,10 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"

View File

@@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
return true, nil
}
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
var err error
var groups []*types.Group
var peers []*nbpeer.Peer
@@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
return nil, nil, err
}
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return nil, err
return nil, nil, err
}
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
return nil, nil, err
}
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
if err != nil {
return nil, nil, err
}
invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra)
if err != nil {
return nil, nil, err
}
return validPeers, invalidPeers, nil
}
type MockIntegratedValidator struct {
@@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) {
return make(map[string]string), nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer
}

View File

@@ -15,6 +15,7 @@ type IntegratedValidator interface {
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
Stop(ctx context.Context)

View File

@@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
return nil, nil, err
}
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
approvedPeers[id] = struct{}{}
}
return approvedPeers, nil
return approvedPeers, nil, nil
}
// GetGroup mock implementation of GetGroup from server.AccountManager interface

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -22,6 +23,7 @@ type Manager interface {
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
SetAccountManager(accountManager account.Manager)
}
type managerImpl struct {
@@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
return permissions, nil
}
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
// no-op
}

View File

@@ -9,6 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
account "github.com/netbirdio/netbird/management/server/account"
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
@@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
}
// SetAccountManager mocks base method.
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAccountManager", accountManager)
}
// SetAccountManager indicates an expected call of SetAccountManager.
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
}
// ValidateAccountAccess mocks base method.
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
m.ctrl.T.Helper()

View File

@@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
@@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool {
}
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
peerIPs := make(map[string]struct{})
@@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
peerIPs[peerToConnect.IP.String()] = struct{}{}
}
for _, expiredPeer := range expiredPeers {
peerIPs[expiredPeer.IP.String()] = struct{}{}
}
for _, record := range customZone.Records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)

View File

@@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
peer *nbpeer.Peer
customZone nbdns.CustomZone
peersToConnect []*nbpeer.Peer
expiredPeers []*nbpeer.Peer
expectedRecords []nbdns.SimpleRecord
}{
{
@@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
},
},
peersToConnect: []*nbpeer.Peer{},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
@@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
}
return peers
}(),
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
@@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
@@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
{
name: "expired peers are included in DNS entries",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{
{ID: "peer1", IP: net.ParseIP("10.0.0.1")},
},
expiredPeers: []*nbpeer.Peer{
{ID: "expired-peer", IP: net.ParseIP("10.0.0.99")},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers)
assert.Equal(t, len(tt.expectedRecords), len(result))
assert.ElementsMatch(t, tt.expectedRecords, result)
})

View File

@@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then
NETBIRD_RELEASE=latest
fi
TAG_NAME=""
get_release() {
local RELEASE=$1
if [ "$RELEASE" = "latest" ]; then
@@ -38,17 +40,19 @@ get_release() {
local TAG="tags/${RELEASE}"
local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
fi
OUTPUT=""
if [ -n "$GITHUB_TOKEN" ]; then
curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}")
else
curl -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
OUTPUT=$(curl -s "${URL}")
fi
TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1)
echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+'
}
download_release_binary() {
VERSION=$(get_release "$NETBIRD_RELEASE")
echo "Using the following tag name for binary installation: ${TAG_NAME}"
BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download"
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz"

View File

@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}

View File

@@ -463,6 +463,9 @@ components:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
disapproval_reason:
description: (Cloud only) Reason why the peer requires approval
type: string
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:

View File

@@ -1037,6 +1037,9 @@ type Peer struct {
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DisapprovalReason (Cloud only) Reason why the peer requires approval
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1124,6 +1127,9 @@ type PeerBatch struct {
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DisapprovalReason (Cloud only) Reason why the peer requires approval
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`

View File

@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}

View File

@@ -94,7 +94,7 @@ var (
startPprof()
opts, certManager, err := getTLSConfigurations()
opts, certManager, tlsConfig, err := getTLSConfigurations()
if err != nil {
return err
}
@@ -132,7 +132,7 @@ var (
// Start the main server - always serve HTTP with WebSocket proxy support
// If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager
if certManager == nil {
if tlsConfig == nil {
// Without TLS, serve plain HTTP
httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort))
if err != nil {
@@ -140,9 +140,10 @@ var (
}
log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String())
serveHTTP(httpListener, grpcRootHandler)
} else if signalPort != 443 {
// With TLS but not on port 443, serve HTTPS
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig())
} else if certManager == nil || signalPort != 443 {
// Serve HTTPS if not already handled by startServerWithCertManager
// (custom certificates or Let's Encrypt with custom port)
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig)
if err != nil {
return err
}
@@ -202,7 +203,7 @@ func startPprof() {
}()
}
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) {
var (
err error
certManager *autocert.Manager
@@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
log.Infof("running without TLS")
return nil, nil, nil
return nil, nil, nil, nil
}
if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil {
return nil, certManager, err
return nil, certManager, nil, err
}
tlsConfig = certManager.TLSConfig()
log.Infof("setting up TLS with LetsEncrypt.")
} else {
if signalCertFile == "" || signalCertKey == "" {
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
}
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
return nil, certManager, err
return nil, certManager, nil, err
}
log.Infof("setting up TLS with custom certificates.")
}
transportCredentials := credentials.NewTLS(tlsConfig)
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err
}
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {