mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
minor changes after merge main
This commit is contained in:
@@ -4,7 +4,7 @@
|
|||||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||||
|
|
||||||
FROM alpine:3.22.0
|
FROM alpine:3.22.2
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
bash \
|
bash \
|
||||||
|
|||||||
@@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||||
} else {
|
} else {
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
|
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return statusOutputString
|
return statusOutputString
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
if err := openBrowser(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||||
|
func openBrowser(url string) error {
|
||||||
|
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||||
|
return exec.Command(browser, url).Start()
|
||||||
|
}
|
||||||
|
return open.Run(url)
|
||||||
|
}
|
||||||
|
|
||||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isUnixRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
|||||||
@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include action in the ipset name to prevent squashing rules with different actions
|
|
||||||
actionSuffix := ""
|
actionSuffix := ""
|
||||||
if action == firewall.ActionDrop {
|
if action == firewall.ActionDrop {
|
||||||
actionSuffix = "-drop"
|
actionSuffix = "-drop"
|
||||||
|
|||||||
@@ -29,11 +29,6 @@ type Manager interface {
|
|||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type protoMatch struct {
|
|
||||||
ips map[string]int
|
|
||||||
policyID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
rules := networkMap.FirewallRules
|
||||||
|
|
||||||
enableSSH := networkMap.PeerConfig != nil &&
|
enableSSH := networkMap.PeerConfig != nil &&
|
||||||
networkMap.PeerConfig.SshConfig != nil &&
|
networkMap.PeerConfig.SshConfig != nil &&
|
||||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// if TCP protocol rules not squashed and SSH enabled
|
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
||||||
// we add default firewall rule which accepts connection to any peer
|
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
||||||
// in the network by SSH (TCP 22 port).
|
|
||||||
if enableSSH {
|
if enableSSH {
|
||||||
rules = append(rules, &mgmProto.FirewallRule{
|
rules = append(rules, &mgmProto.FirewallRule{
|
||||||
PeerIP: "0.0.0.0",
|
PeerIP: "0.0.0.0",
|
||||||
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
|
|||||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||||
}
|
}
|
||||||
|
|
||||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
|
||||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
|
||||||
//
|
|
||||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
|
||||||
// but other has port definitions or has drop policy.
|
|
||||||
func (d *DefaultManager) squashAcceptRules(
|
|
||||||
networkMap *mgmProto.NetworkMap,
|
|
||||||
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
|
||||||
totalIPs := 0
|
|
||||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
|
||||||
for range p.AllowedIps {
|
|
||||||
totalIPs++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
|
||||||
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
|
||||||
// But we zeroed the IP's for protocol if:
|
|
||||||
// 1. Any of the rule has DROP action type.
|
|
||||||
// 2. Any of rule contains Port.
|
|
||||||
//
|
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
|
||||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
|
||||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
|
||||||
|
|
||||||
if hasPortRestrictions {
|
|
||||||
// Don't squash rules with port restrictions
|
|
||||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
|
||||||
protocols[r.Protocol] = &protoMatch{
|
|
||||||
ips: map[string]int{},
|
|
||||||
// store the first encountered PolicyID for this protocol
|
|
||||||
policyID: r.PolicyID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// special case, when we receive this all network IP address
|
|
||||||
// it means that rules for that protocol was already optimized on the
|
|
||||||
// management side
|
|
||||||
if r.PeerIP == "0.0.0.0" {
|
|
||||||
squashedRules = append(squashedRules, r)
|
|
||||||
squashedProtocols[r.Protocol] = struct{}{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ipset := protocols[r.Protocol].ips
|
|
||||||
|
|
||||||
if _, ok := ipset[r.PeerIP]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ipset[r.PeerIP] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
// calculate squash for different directions
|
|
||||||
if r.Direction == mgmProto.RuleDirection_IN {
|
|
||||||
addRuleToCalculationMap(i, r, in)
|
|
||||||
} else {
|
|
||||||
addRuleToCalculationMap(i, r, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// order of squashing by protocol is important
|
|
||||||
// only for their first element ALL, it must be done first
|
|
||||||
protocolOrders := []mgmProto.RuleProtocol{
|
|
||||||
mgmProto.RuleProtocol_ALL,
|
|
||||||
mgmProto.RuleProtocol_ICMP,
|
|
||||||
mgmProto.RuleProtocol_TCP,
|
|
||||||
mgmProto.RuleProtocol_UDP,
|
|
||||||
}
|
|
||||||
|
|
||||||
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
|
||||||
for _, protocol := range protocolOrders {
|
|
||||||
match, ok := matches[protocol]
|
|
||||||
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
|
||||||
// don't squash if :
|
|
||||||
// 1. Rules not cover all peers in the network
|
|
||||||
// 2. Rules cover only one peer in the network.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
|
||||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: direction,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: protocol,
|
|
||||||
PolicyID: match.policyID,
|
|
||||||
})
|
|
||||||
squashedProtocols[protocol] = struct{}{}
|
|
||||||
|
|
||||||
if protocol == mgmProto.RuleProtocol_ALL {
|
|
||||||
// if we have ALL traffic type squashed rule
|
|
||||||
// it allows all other type of traffic, so we can stop processing
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
squash(in, mgmProto.RuleDirection_IN)
|
|
||||||
squash(out, mgmProto.RuleDirection_OUT)
|
|
||||||
|
|
||||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
|
||||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
|
||||||
return squashedRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(squashedRules) == 0 {
|
|
||||||
return networkMap.FirewallRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []*mgmProto.FirewallRule
|
|
||||||
// filter out rules which was squashed from final list
|
|
||||||
// if we also have other not squashed rules.
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
|
||||||
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rules = append(rules, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, squashedRules...), squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||||
|
|||||||
@@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerSquashRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
assert.Equal(t, 2, len(rules))
|
|
||||||
|
|
||||||
r := rules[0]
|
|
||||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
|
||||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
|
||||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
|
||||||
|
|
||||||
r = rules[1]
|
|
||||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
|
||||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
|
||||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rules []*mgmProto.FirewallRule
|
|
||||||
expectedCount int
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "should not squash rules with port ranges",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with specific ports",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with legacy port field",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with legacy port field should not be squashed",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with DROP action",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with DROP action should not be squashed",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should squash rules without port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 1,
|
|
||||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed rules should not squash protocol with port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "TCP should not be squashed because one rule has port restrictions",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
// TCP rules with port restrictions - should NOT be squashed
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
// UDP rules without port restrictions - SHOULD be squashed
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
|
||||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: tt.rules,
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
|
||||||
|
|
||||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
|
||||||
if tt.expectedCount == 1 {
|
|
||||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
|
||||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
|
|||||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||||
config.txt: Anonymized configuration information of the NetBird client.
|
config.txt: Anonymized configuration information of the NetBird client.
|
||||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||||
state.json: Anonymized client state dump containing netbird states.
|
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||||
mutex.prof: Mutex profiling information.
|
mutex.prof: Mutex profiling information.
|
||||||
goroutine.prof: Goroutine profiling information.
|
goroutine.prof: Goroutine profiling information.
|
||||||
block.prof: Block profiling information.
|
block.prof: Block profiling information.
|
||||||
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding state file from: %s", path)
|
||||||
|
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
|
|
||||||
err = s.recordSystemDNSSettings(true)
|
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
searchDomains = append(searchDomains, "\"\"")
|
searchDomains = append(searchDomains, "\"\"")
|
||||||
err = s.addLocalDNS()
|
if err := s.addLocalDNS(); err != nil {
|
||||||
if err != nil {
|
log.Warnf("failed to add local DNS: %v", err)
|
||||||
log.Infof("failed to enable split DNS")
|
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
var err error
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||||
} else {
|
} else {
|
||||||
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add match domains: %w", err)
|
return fmt.Errorf("add match domains: %w", err)
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
if len(searchDomains) != 0 {
|
if len(searchDomains) != 0 {
|
||||||
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
|
|
||||||
if err := s.flushDNSCache(); err != nil {
|
if err := s.flushDNSCache(); err != nil {
|
||||||
log.Errorf("failed to flush DNS cache: %v", err)
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) string() string {
|
func (s *systemConfigurator) string() string {
|
||||||
return "scutil"
|
return "scutil"
|
||||||
}
|
}
|
||||||
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
|||||||
func (s *systemConfigurator) addLocalDNS() error {
|
func (s *systemConfigurator) addLocalDNS() error {
|
||||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||||
log.Errorf("Unable to get system DNS configuration")
|
|
||||||
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
|
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Info("Not enabling local DNS server")
|
log.Info("Not enabling local DNS server")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.addSearchDomains(
|
||||||
|
localKey,
|
||||||
|
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
111
client/internal/dns/host_darwin_test.go
Normal file
111
client/internal/dns/host_darwin_test.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping scutil integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
stateFile := filepath.Join(tmpDir, "state.json")
|
||||||
|
|
||||||
|
sm := statemanager.New(stateFile)
|
||||||
|
sm.RegisterState(&ShutdownState{})
|
||||||
|
sm.Start()
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, sm.Stop(context.Background()))
|
||||||
|
}()
|
||||||
|
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: true,
|
||||||
|
Domains: []DomainConfig{
|
||||||
|
{Domain: "example.com", MatchOnly: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, sm.PersistState(context.Background()))
|
||||||
|
|
||||||
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if exists {
|
||||||
|
t.Logf("Key %s exists before cleanup", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sm2 := statemanager.New(stateFile)
|
||||||
|
sm2.RegisterState(&ShutdownState{})
|
||||||
|
err = sm2.LoadState(&ShutdownState{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
state := sm2.GetState(&ShutdownState{})
|
||||||
|
if state == nil {
|
||||||
|
t.Skip("State not saved, skipping cleanup test")
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownState, ok := state.(*ShutdownState)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
err = shutdownState.Cleanup()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDNSKeyExists(key string) (bool, error) {
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(string(output), "No such key") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return !strings.Contains(string(output), "No such key"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeTestDNSKey(key string) error {
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
|
||||||
|
_, err := cmd.CombinedOutput()
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
r.updateState(stateManager)
|
||||||
Guid: r.guid,
|
|
||||||
GPO: r.gpo,
|
|
||||||
NRPTEntryCount: r.nrptEntryCount,
|
|
||||||
}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var searchDomains, matchDomains []string
|
var searchDomains, matchDomains []string
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||||
|
log.Errorf("cleanup old dns match policies: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
}
|
}
|
||||||
r.nrptEntryCount = count
|
r.nrptEntryCount = count
|
||||||
} else {
|
} else {
|
||||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
|
||||||
return fmt.Errorf("remove dns match policies: %w", err)
|
|
||||||
}
|
|
||||||
r.nrptEntryCount = 0
|
r.nrptEntryCount = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
r.updateState(stateManager)
|
||||||
Guid: r.guid,
|
|
||||||
GPO: r.gpo,
|
|
||||||
NRPTEntryCount: r.nrptEntryCount,
|
|
||||||
}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
|
Guid: r.guid,
|
||||||
|
GPO: r.gpo,
|
||||||
|
NRPTEntryCount: r.nrptEntryCount,
|
||||||
|
}); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
||||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||||
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
|
|||||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||||
}
|
}
|
||||||
defer closer(regKey)
|
defer closer(regKey)
|
||||||
|
|
||||||
|
|||||||
102
client/internal/dns/host_windows_test.go
Normal file
102
client/internal/dns/host_windows_test.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||||
|
// when the number of match domains decreases between configuration changes.
|
||||||
|
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping registry integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer cleanupRegistryKeys(t)
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
testIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||||
|
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||||
|
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||||
|
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||||
|
require.NoError(t, err, "Should create test interface registry key")
|
||||||
|
testKey.Close()
|
||||||
|
defer func() {
|
||||||
|
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := ®istryConfigurator{
|
||||||
|
guid: testGUID,
|
||||||
|
gpo: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
config5 := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: []DomainConfig{
|
||||||
|
{Domain: "domain1.com", MatchOnly: true},
|
||||||
|
{Domain: "domain2.com", MatchOnly: true},
|
||||||
|
{Domain: "domain3.com", MatchOnly: true},
|
||||||
|
{Domain: "domain4.com", MatchOnly: true},
|
||||||
|
{Domain: "domain5.com", MatchOnly: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cfg.applyDNSConfig(config5, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify all 5 entries exist
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
config2 := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: []DomainConfig{
|
||||||
|
{Domain: "domain1.com", MatchOnly: true},
|
||||||
|
{Domain: "domain2.com", MatchOnly: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cfg.applyDNSConfig(config2, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify first 2 entries exist
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify entries 2-4 are cleaned up
|
||||||
|
for i := 2; i < 5; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func registryKeyExists(path string) (bool, error) {
|
||||||
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
if err == registry.ErrNotExist {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
k.Close()
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRegistryKeys(*testing.T) {
|
||||||
|
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||||
|
_ = cfg.removeDNSMatchPolicies()
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
|
CreatedKeys []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
return fmt.Errorf("create host manager: %w", err)
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, key := range s.CreatedKeys {
|
||||||
|
manager.createdKeys[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ type DefaultManager struct {
|
|||||||
func NewManager(config ManagerConfig) *DefaultManager {
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(config.Context)
|
mCTX, cancel := context.WithCancel(config.Context)
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
sysOps := systemops.New(config.WGInterface, notifier)
|
||||||
|
|
||||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||||
|
|||||||
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// FlushMarkedRoutes is a no-op on non-BSD platforms.
|
||||||
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
sysops := NewSysOps(nil, nil)
|
sysOps := New(nil, nil)
|
||||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
|
||||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
sysOps.refCounter.LoadData((*ExclusionCounter)(s))
|
||||||
|
|
||||||
return sysops.refCounter.Flush()
|
return sysOps.refCounter.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ type SysOps struct {
|
|||||||
localSubnetsCacheTime time.Time
|
localSubnetsCacheTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||||
return &SysOps{
|
return &SysOps{
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
|||||||
_, intf = setupDummyInterface(t)
|
_, intf = setupDummyInterface(t)
|
||||||
nexthop = Nexthop{netip.Addr{}, intf}
|
nexthop = Nexthop{netip.Addr{}, intf}
|
||||||
|
|
||||||
r := NewSysOps(nil, nil)
|
r := New(nil, nil)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 1024; i++ {
|
for i := 0; i < 1024; i++ {
|
||||||
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
|
|||||||
|
|
||||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||||
|
|
||||||
r := NewSysOps(nil, nil)
|
r := New(nil, nil)
|
||||||
err = r.addToRouteTable(prefix, nexthop)
|
err = r.addToRouteTable(prefix, nexthop)
|
||||||
require.NoError(t, err, "Failed to add route to table")
|
require.NoError(t, err, "Failed to add route to table")
|
||||||
|
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
|
|
||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
|
|
||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
|
|||||||
assert.NoError(t, wgInterface.Close())
|
assert.NoError(t, wgInterface.Close())
|
||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
|
|||||||
@@ -7,19 +7,39 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||||
|
)
|
||||||
|
|
||||||
|
var routeProtoFlag int
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
switch os.Getenv(envRouteProtoFlag) {
|
||||||
|
case "2":
|
||||||
|
routeProtoFlag = unix.RTF_PROTO2
|
||||||
|
case "3":
|
||||||
|
routeProtoFlag = unix.RTF_PROTO3
|
||||||
|
default:
|
||||||
|
routeProtoFlag = unix.RTF_PROTO1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
|
|||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||||
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
|
rib, err := retryFetchRIB()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetch routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
flushedCount := 0
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
rtMsg, ok := msg.(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
routeInfo, err := MsgToRoute(rtMsg)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Skipping route flush: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nexthop := Nexthop{
|
||||||
|
IP: routeInfo.Gw,
|
||||||
|
Intf: routeInfo.Interface,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
flushedCount++
|
||||||
|
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if flushedCount > 0 {
|
||||||
|
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||||
}
|
}
|
||||||
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
|
|||||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||||
msg = &route.RouteMessage{
|
msg = &route.RouteMessage{
|
||||||
Type: action,
|
Type: action,
|
||||||
Flags: unix.RTF_UP,
|
Flags: unix.RTF_UP | routeProtoFlag,
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
Seq: r.getSeq(),
|
Seq: r.getSeq(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
|
|||||||
data, err := os.ReadFile(m.filePath)
|
data, err := os.ReadFile(m.filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
log.Debug("state file does not exist")
|
log.Debugf("state file %s does not exist", m.filePath)
|
||||||
return nil, nil // nolint:nilnil
|
return nil, nil // nolint:nilnil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("read state file: %w", err)
|
return nil, fmt.Errorf("read state file: %w", err)
|
||||||
|
|||||||
59
client/internal/winregistry/volatile_windows.go
Normal file
59
client/internal/winregistry/volatile_windows.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package winregistry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
advapi = syscall.NewLazyDLL("advapi32.dll")
|
||||||
|
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Registry key options
|
||||||
|
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
|
||||||
|
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
|
||||||
|
|
||||||
|
// Registry disposition values
|
||||||
|
regCreatedNewKey = 0x1
|
||||||
|
regOpenedExistingKey = 0x2
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateVolatileKey creates a volatile registry key named path under open key root.
|
||||||
|
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
|
||||||
|
// The access parameter specifies the access rights for the key to be created.
|
||||||
|
//
|
||||||
|
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
|
||||||
|
// This provides automatic cleanup without requiring manual registry maintenance.
|
||||||
|
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
|
||||||
|
pathPtr, err := syscall.UTF16PtrFromString(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
handle syscall.Handle
|
||||||
|
disposition uint32
|
||||||
|
)
|
||||||
|
|
||||||
|
ret, _, _ := regCreateKeyExW.Call(
|
||||||
|
uintptr(root),
|
||||||
|
uintptr(unsafe.Pointer(pathPtr)),
|
||||||
|
0, // reserved
|
||||||
|
0, // class
|
||||||
|
uintptr(regOptionVolatile), // options - volatile key
|
||||||
|
uintptr(access), // desired access
|
||||||
|
0, // security attributes
|
||||||
|
uintptr(unsafe.Pointer(&handle)),
|
||||||
|
uintptr(unsafe.Pointer(&disposition)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret != 0 {
|
||||||
|
return 0, false, syscall.Errno(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
return registry.Key(handle), disposition == regOpenedExistingKey, nil
|
||||||
|
}
|
||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clean up any remaining routes independently of the state file
|
||||||
|
if !nbnet.AdvancedRouting() {
|
||||||
|
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -205,15 +205,18 @@ func mapPeers(
|
|||||||
localICEEndpoint := ""
|
localICEEndpoint := ""
|
||||||
remoteICEEndpoint := ""
|
remoteICEEndpoint := ""
|
||||||
relayServerAddress := ""
|
relayServerAddress := ""
|
||||||
connType := "P2P"
|
connType := "-"
|
||||||
lastHandshake := time.Time{}
|
lastHandshake := time.Time{}
|
||||||
transferReceived := int64(0)
|
transferReceived := int64(0)
|
||||||
transferSent := int64(0)
|
transferSent := int64(0)
|
||||||
|
|
||||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||||
|
|
||||||
if pbPeerState.Relayed {
|
if isPeerConnected {
|
||||||
connType = "Relayed"
|
connType = "P2P"
|
||||||
|
if pbPeerState.Relayed {
|
||||||
|
connType = "Relayed"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
|
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ import (
|
|||||||
"fyne.io/systray"
|
"fyne.io/systray"
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
@@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||||
err := open.Run(loginResp.VerificationURIComplete)
|
err := openURL(loginResp.VerificationURIComplete)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func openURL(url string) error {
|
func openURL(url string) error {
|
||||||
|
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||||
|
return exec.Command(browser, url).Start()
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||||
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to get post-up status: %v", err)
|
log.Warnf("Failed to get post-up status: %v", err)
|
||||||
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
|
|
||||||
var postUpStatusOutput string
|
var postUpStatusOutput string
|
||||||
if postUpStatus != nil {
|
if postUpStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
|
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
|
|
||||||
var preDownStatusOutput string
|
var preDownStatusOutput string
|
||||||
if preDownStatus != nil {
|
if preDownStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
|
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||||
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
return nil, fmt.Errorf("get client: %v", err)
|
return nil, fmt.Errorf("get client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get status for debug bundle: %v", err)
|
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||||
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
|
|
||||||
var statusOutput string
|
var statusOutput string
|
||||||
if statusResp != nil {
|
if statusResp != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
|
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -63,7 +63,7 @@ require (
|
|||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
|
|||||||
@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
|
|||||||
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
|
||||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
unset NETBIRD_LETSENCRYPT_DOMAIN
|
||||||
unset NETBIRD_MGMT_API_CERT_FILE
|
unset NETBIRD_MGMT_API_CERT_FILE
|
||||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
|
||||||
|
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||||
|
fi
|
||||||
|
|
||||||
# Check if management identity provider is set
|
# Check if management identity provider is set
|
||||||
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
||||||
EXTRA_CONFIG={}
|
EXTRA_CONFIG={}
|
||||||
|
|||||||
@@ -40,13 +40,21 @@ services:
|
|||||||
signal:
|
signal:
|
||||||
<<: *default
|
<<: *default
|
||||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||||
|
depends_on:
|
||||||
|
- dashboard
|
||||||
volumes:
|
volumes:
|
||||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||||
|
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||||
ports:
|
ports:
|
||||||
- $NETBIRD_SIGNAL_PORT:80
|
- $NETBIRD_SIGNAL_PORT:80
|
||||||
# # port and command for Let's Encrypt validation
|
# # port and command for Let's Encrypt validation
|
||||||
# - 443:443
|
# - 443:443
|
||||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||||
|
command: [
|
||||||
|
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||||
|
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||||
|
"--log-file", "console"
|
||||||
|
]
|
||||||
|
|
||||||
# Relay
|
# Relay
|
||||||
relay:
|
relay:
|
||||||
|
|||||||
@@ -682,17 +682,6 @@ renderManagementJson() {
|
|||||||
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"TURNConfig": {
|
|
||||||
"Turns": [
|
|
||||||
{
|
|
||||||
"Proto": "udp",
|
|
||||||
"URI": "turn:$NETBIRD_DOMAIN:3478",
|
|
||||||
"Username": "$TURN_USER",
|
|
||||||
"Password": "$TURN_PASSWORD"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"TimeBasedCredentials": false
|
|
||||||
},
|
|
||||||
"Relay": {
|
"Relay": {
|
||||||
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
|
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
|
||||||
"CredentialsTTL": "24h",
|
"CredentialsTTL": "24h",
|
||||||
|
|||||||
@@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
|||||||
|
|
||||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||||
return Create(s, func() permissions.Manager {
|
return Create(s, func() permissions.Manager {
|
||||||
return integrations.InitPermissionsManager(s.Store())
|
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
||||||
|
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
manager.SetAccountManager(s.AccountManager())
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ type Manager interface {
|
|||||||
GetIdpManager() idp.Manager
|
GetIdpManager() idp.Manager
|
||||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
|||||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||||
@@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, valid := validPeers[peer.ID]
|
_, valid := validPeers[peer.ID]
|
||||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid))
|
reason := invalidPeers[peer.ID]
|
||||||
|
|
||||||
|
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
|||||||
|
|
||||||
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
|
||||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, valid := validPeers[peer.ID]
|
_, valid := validPeers[peer.ID]
|
||||||
|
reason := invalidPeers[peer.ID]
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid))
|
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||||
@@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
|
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.setApprovalRequiredFlag(respBody, validPeersMap)
|
h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, respBody)
|
util.WriteJSONObject(r.Context(), w, respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
|
||||||
for _, peer := range respBody {
|
for _, peer := range respBody {
|
||||||
_, ok := approvedPeersMap[peer.Id]
|
_, ok := validPeersMap[peer.Id]
|
||||||
if !ok {
|
if !ok {
|
||||||
peer.ApprovalRequired = true
|
peer.ApprovalRequired = true
|
||||||
|
|
||||||
|
reason := invalidPeersMap[peer.Id]
|
||||||
|
peer.DisapprovalReason = &reason
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||||
@@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
|
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
|
||||||
osVersion := peer.Meta.OSVersion
|
osVersion := peer.Meta.OSVersion
|
||||||
if osVersion == "" {
|
if osVersion == "" {
|
||||||
osVersion = peer.Meta.Core
|
osVersion = peer.Meta.Core
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Peer{
|
apiPeer := &api.Peer{
|
||||||
CreatedAt: peer.CreatedAt,
|
CreatedAt: peer.CreatedAt,
|
||||||
Id: peer.ID,
|
Id: peer.ID,
|
||||||
Name: peer.Name,
|
Name: peer.Name,
|
||||||
@@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
|||||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
Ephemeral: peer.Ephemeral,
|
Ephemeral: peer.Ephemeral,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !approved {
|
||||||
|
apiPeer.DisapprovalReason = &reason
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiPeer
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
|
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||||
var err error
|
var err error
|
||||||
var groups []*types.Group
|
var groups []*types.Group
|
||||||
var peers []*nbpeer.Peer
|
var peers []*nbpeer.Peer
|
||||||
@@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
|||||||
|
|
||||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return validPeers, invalidPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockIntegratedValidator struct {
|
type MockIntegratedValidator struct {
|
||||||
@@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
|
|||||||
return validatedPeers, nil
|
return validatedPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) {
|
||||||
|
return make(map[string]string), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
|
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ type IntegratedValidator interface {
|
|||||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
||||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||||
|
GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error)
|
||||||
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
||||||
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
|
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
|
||||||
Stop(ctx context.Context)
|
Stop(ctx context.Context)
|
||||||
|
|||||||
@@ -190,17 +190,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
|
|||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||||
account, err := am.GetAccountFunc(ctx, accountID)
|
account, err := am.GetAccountFunc(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
approvedPeers := make(map[string]struct{})
|
approvedPeers := make(map[string]struct{})
|
||||||
for id := range account.Peers {
|
for id := range account.Peers {
|
||||||
approvedPeers[id] = struct{}{}
|
approvedPeers[id] = struct{}{}
|
||||||
}
|
}
|
||||||
return approvedPeers, nil
|
return approvedPeers, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroup mock implementation of GetGroup from server.AccountManager interface
|
// GetGroup mock implementation of GetGroup from server.AccountManager interface
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -22,6 +23,7 @@ type Manager interface {
|
|||||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
||||||
|
|
||||||
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
|
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
|
||||||
|
SetAccountManager(accountManager account.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
|
|||||||
|
|
||||||
return permissions, nil
|
return permissions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
account "github.com/netbirdio/netbird/management/server/account"
|
||||||
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
|
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
|
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
|
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||||
@@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountManager mocks base method.
|
||||||
|
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetAccountManager", accountManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAccountManager indicates an expected call of SetAccountManager.
|
||||||
|
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateAccountAccess mocks base method.
|
// ValidateAccountAccess mocks base method.
|
||||||
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
|||||||
peer *nbpeer.Peer
|
peer *nbpeer.Peer
|
||||||
customZone nbdns.CustomZone
|
customZone nbdns.CustomZone
|
||||||
peersToConnect []*nbpeer.Peer
|
peersToConnect []*nbpeer.Peer
|
||||||
|
expiredPeers []*nbpeer.Peer
|
||||||
expectedRecords []nbdns.SimpleRecord
|
expectedRecords []nbdns.SimpleRecord
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
peersToConnect: []*nbpeer.Peer{},
|
peersToConnect: []*nbpeer.Peer{},
|
||||||
|
expiredPeers: []*nbpeer.Peer{},
|
||||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||||
expectedRecords: []nbdns.SimpleRecord{
|
expectedRecords: []nbdns.SimpleRecord{
|
||||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||||
@@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return peers
|
return peers
|
||||||
}(),
|
}(),
|
||||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
expiredPeers: []*nbpeer.Peer{},
|
||||||
|
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||||
expectedRecords: func() []nbdns.SimpleRecord {
|
expectedRecords: func() []nbdns.SimpleRecord {
|
||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||||
@@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
|||||||
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
|
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
|
||||||
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
|
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
|
||||||
},
|
},
|
||||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
expiredPeers: []*nbpeer.Peer{},
|
||||||
|
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||||
expectedRecords: []nbdns.SimpleRecord{
|
expectedRecords: []nbdns.SimpleRecord{
|
||||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
@@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
|||||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "expired peers are included in DNS entries",
|
||||||
|
customZone: nbdns.CustomZone{
|
||||||
|
Domain: "netbird.cloud.",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
|
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||||
|
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
|
||||||
|
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peersToConnect: []*nbpeer.Peer{
|
||||||
|
{ID: "peer1", IP: net.ParseIP("10.0.0.1")},
|
||||||
|
},
|
||||||
|
expiredPeers: []*nbpeer.Peer{
|
||||||
|
{ID: "expired-peer", IP: net.ParseIP("10.0.0.99")},
|
||||||
|
},
|
||||||
|
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||||
|
expectedRecords: []nbdns.SimpleRecord{
|
||||||
|
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
|
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
|
||||||
|
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
|
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers)
|
||||||
assert.Equal(t, len(tt.expectedRecords), len(result))
|
assert.Equal(t, len(tt.expectedRecords), len(result))
|
||||||
assert.ElementsMatch(t, tt.expectedRecords, result)
|
assert.ElementsMatch(t, tt.expectedRecords, result)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -136,9 +136,8 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
|
|
||||||
if dnsManagementStatus {
|
if dnsManagementStatus {
|
||||||
var zones []nbdns.CustomZone
|
var zones []nbdns.CustomZone
|
||||||
|
|
||||||
if peersCustomZone.Domain != "" {
|
if peersCustomZone.Domain != "" {
|
||||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
|
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||||
zones = append(zones, nbdns.CustomZone{
|
zones = append(zones, nbdns.CustomZone{
|
||||||
Domain: peersCustomZone.Domain,
|
Domain: peersCustomZone.Domain,
|
||||||
Records: records,
|
Records: records,
|
||||||
@@ -148,14 +147,6 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// nm := GetNetworkMap()
|
|
||||||
// nm.Peers = peersToConnectIncludingRouters
|
|
||||||
// nm.Network = a.Network.Copy()
|
|
||||||
// nm.Routes = slices.Concat(networkResourcesRoutes, routesUpdate)
|
|
||||||
// nm.DNSConfig = dnsUpdate
|
|
||||||
// nm.OfflinePeers = expiredPeers
|
|
||||||
// nm.FirewallRules = firewallRules
|
|
||||||
// nm.RoutesFirewallRules = slices.Concat(networkResourcesFirewallRules, routesFirewallRules)
|
|
||||||
nm := &NetworkMap{
|
nm := &NetworkMap{
|
||||||
Peers: peersToConnectIncludingRouters,
|
Peers: peersToConnectIncludingRouters,
|
||||||
Network: a.Network.Copy(),
|
Network: a.Network.Copy(),
|
||||||
@@ -929,7 +920,7 @@ func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{})
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
|
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||||
peerIPs := make(map[string]struct{})
|
peerIPs := make(map[string]struct{})
|
||||||
|
|
||||||
@@ -940,6 +931,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
|
|||||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, expiredPeer := range expiredPeers {
|
||||||
|
peerIPs[expiredPeer.IP.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
if _, exists := peerIPs[record.RData]; exists {
|
if _, exists := peerIPs[record.RData]; exists {
|
||||||
filteredRecords = append(filteredRecords, record)
|
filteredRecords = append(filteredRecords, record)
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then
|
|||||||
NETBIRD_RELEASE=latest
|
NETBIRD_RELEASE=latest
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
TAG_NAME=""
|
||||||
|
|
||||||
get_release() {
|
get_release() {
|
||||||
local RELEASE=$1
|
local RELEASE=$1
|
||||||
if [ "$RELEASE" = "latest" ]; then
|
if [ "$RELEASE" = "latest" ]; then
|
||||||
@@ -38,17 +40,19 @@ get_release() {
|
|||||||
local TAG="tags/${RELEASE}"
|
local TAG="tags/${RELEASE}"
|
||||||
local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
|
local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
|
||||||
fi
|
fi
|
||||||
|
OUTPUT=""
|
||||||
if [ -n "$GITHUB_TOKEN" ]; then
|
if [ -n "$GITHUB_TOKEN" ]; then
|
||||||
curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
|
OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}")
|
||||||
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
|
|
||||||
else
|
else
|
||||||
curl -s "${URL}" \
|
OUTPUT=$(curl -s "${URL}")
|
||||||
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
|
|
||||||
fi
|
fi
|
||||||
|
TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1)
|
||||||
|
echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+'
|
||||||
}
|
}
|
||||||
|
|
||||||
download_release_binary() {
|
download_release_binary() {
|
||||||
VERSION=$(get_release "$NETBIRD_RELEASE")
|
VERSION=$(get_release "$NETBIRD_RELEASE")
|
||||||
|
echo "Using the following tag name for binary installation: ${TAG_NAME}"
|
||||||
BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download"
|
BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download"
|
||||||
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz"
|
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz"
|
||||||
|
|
||||||
|
|||||||
@@ -463,6 +463,9 @@ components:
|
|||||||
description: (Cloud only) Indicates whether peer needs approval
|
description: (Cloud only) Indicates whether peer needs approval
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
disapproval_reason:
|
||||||
|
description: (Cloud only) Reason why the peer requires approval
|
||||||
|
type: string
|
||||||
country_code:
|
country_code:
|
||||||
$ref: '#/components/schemas/CountryCode'
|
$ref: '#/components/schemas/CountryCode'
|
||||||
city_name:
|
city_name:
|
||||||
|
|||||||
@@ -1037,6 +1037,9 @@ type Peer struct {
|
|||||||
// CreatedAt Peer creation date (UTC)
|
// CreatedAt Peer creation date (UTC)
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
|
// DisapprovalReason (Cloud only) Reason why the peer requires approval
|
||||||
|
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
|
||||||
|
|
||||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||||
DnsLabel string `json:"dns_label"`
|
DnsLabel string `json:"dns_label"`
|
||||||
|
|
||||||
@@ -1124,6 +1127,9 @@ type PeerBatch struct {
|
|||||||
// CreatedAt Peer creation date (UTC)
|
// CreatedAt Peer creation date (UTC)
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
|
// DisapprovalReason (Cloud only) Reason why the peer requires approval
|
||||||
|
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
|
||||||
|
|
||||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||||
DnsLabel string `json:"dns_label"`
|
DnsLabel string `json:"dns_label"`
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ var (
|
|||||||
|
|
||||||
startPprof()
|
startPprof()
|
||||||
|
|
||||||
opts, certManager, err := getTLSConfigurations()
|
opts, certManager, tlsConfig, err := getTLSConfigurations()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -132,7 +132,7 @@ var (
|
|||||||
|
|
||||||
// Start the main server - always serve HTTP with WebSocket proxy support
|
// Start the main server - always serve HTTP with WebSocket proxy support
|
||||||
// If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager
|
// If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager
|
||||||
if certManager == nil {
|
if tlsConfig == nil {
|
||||||
// Without TLS, serve plain HTTP
|
// Without TLS, serve plain HTTP
|
||||||
httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort))
|
httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -140,9 +140,10 @@ var (
|
|||||||
}
|
}
|
||||||
log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String())
|
log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String())
|
||||||
serveHTTP(httpListener, grpcRootHandler)
|
serveHTTP(httpListener, grpcRootHandler)
|
||||||
} else if signalPort != 443 {
|
} else if certManager == nil || signalPort != 443 {
|
||||||
// With TLS but not on port 443, serve HTTPS
|
// Serve HTTPS if not already handled by startServerWithCertManager
|
||||||
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig())
|
// (custom certificates or Let's Encrypt with custom port)
|
||||||
|
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -202,7 +203,7 @@ func startPprof() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
|
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
certManager *autocert.Manager
|
certManager *autocert.Manager
|
||||||
@@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
|
|||||||
|
|
||||||
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
|
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
|
||||||
log.Infof("running without TLS")
|
log.Infof("running without TLS")
|
||||||
return nil, nil, nil
|
return nil, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if signalLetsencryptDomain != "" {
|
if signalLetsencryptDomain != "" {
|
||||||
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
|
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, certManager, err
|
return nil, certManager, nil, err
|
||||||
}
|
}
|
||||||
tlsConfig = certManager.TLSConfig()
|
tlsConfig = certManager.TLSConfig()
|
||||||
log.Infof("setting up TLS with LetsEncrypt.")
|
log.Infof("setting up TLS with LetsEncrypt.")
|
||||||
} else {
|
} else {
|
||||||
if signalCertFile == "" || signalCertKey == "" {
|
if signalCertFile == "" || signalCertKey == "" {
|
||||||
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
||||||
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
|
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("cannot load TLS credentials: %v", err)
|
log.Errorf("cannot load TLS credentials: %v", err)
|
||||||
return nil, certManager, err
|
return nil, certManager, nil, err
|
||||||
}
|
}
|
||||||
log.Infof("setting up TLS with custom certificates.")
|
log.Infof("setting up TLS with custom certificates.")
|
||||||
}
|
}
|
||||||
|
|
||||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||||
|
|
||||||
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
|
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {
|
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {
|
||||||
|
|||||||
Reference in New Issue
Block a user