mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-28 05:06:38 +00:00
Compare commits
24 Commits
deploy/pee
...
fix/proxy_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9bcdcf149 | ||
|
|
d39814f173 | ||
|
|
4a2429eb1c | ||
|
|
de2e6557ad | ||
|
|
650bca7ca8 | ||
|
|
570e28d227 | ||
|
|
272ade07a8 | ||
|
|
263abe4862 | ||
|
|
ceee421a05 | ||
|
|
0a75da6fb7 | ||
|
|
920877964f | ||
|
|
2e0047daea | ||
|
|
ce0718fcb5 | ||
|
|
c590518e0c | ||
|
|
f309b120cd | ||
|
|
7357a9954c | ||
|
|
13b63eebc1 | ||
|
|
735ed7ab34 | ||
|
|
961d9198ef | ||
|
|
df4ca01848 | ||
|
|
4e7c17756c | ||
|
|
6a4935139d | ||
|
|
35dd991776 | ||
|
|
3598418206 |
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -7,17 +7,7 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
|
||||||
- 'go.mod'
|
|
||||||
- 'go.sum'
|
|
||||||
- '.goreleaser.yml'
|
|
||||||
- '.goreleaser_ui.yaml'
|
|
||||||
- '.goreleaser_ui_darwin.yaml'
|
|
||||||
- '.github/workflows/release.yml'
|
|
||||||
- 'release_files/**'
|
|
||||||
- '**/Dockerfile'
|
|
||||||
- '**/Dockerfile.*'
|
|
||||||
- 'client/ui/**'
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.11"
|
SIGN_PIPE_VER: "v0.0.11"
|
||||||
@@ -106,6 +96,27 @@ jobs:
|
|||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 3
|
retention-days: 3
|
||||||
|
-
|
||||||
|
name: upload linux packages
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: linux-packages
|
||||||
|
path: dist/netbird_linux**
|
||||||
|
retention-days: 3
|
||||||
|
-
|
||||||
|
name: upload windows packages
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: windows-packages
|
||||||
|
path: dist/netbird_windows**
|
||||||
|
retention-days: 3
|
||||||
|
-
|
||||||
|
name: upload macos packages
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: macos-packages
|
||||||
|
path: dist/netbird_darwin**
|
||||||
|
retention-days: 3
|
||||||
|
|
||||||
release_ui:
|
release_ui:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
package android
|
package android
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -14,6 +16,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -59,6 +62,7 @@ type Client struct {
|
|||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||||
|
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
@@ -97,7 +101,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -122,7 +127,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ const (
|
|||||||
preSharedKeyFlag = "preshared-key"
|
preSharedKeyFlag = "preshared-key"
|
||||||
interfaceNameFlag = "interface-name"
|
interfaceNameFlag = "interface-name"
|
||||||
wireguardPortFlag = "wireguard-port"
|
wireguardPortFlag = "wireguard-port"
|
||||||
|
networkMonitorFlag = "network-monitor"
|
||||||
disableAutoConnectFlag = "disable-auto-connect"
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
serverSSHAllowedFlag = "allow-server-ssh"
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
@@ -62,6 +63,7 @@ var (
|
|||||||
serverSSHAllowed bool
|
serverSSHAllowed bool
|
||||||
interfaceName string
|
interfaceName string
|
||||||
wireguardPort uint16
|
wireguardPort uint16
|
||||||
|
networkMonitor bool
|
||||||
serviceName string
|
serviceName string
|
||||||
autoConnectDisabled bool
|
autoConnectDisabled bool
|
||||||
extraIFaceBlackList []string
|
extraIFaceBlackList []string
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||||
|
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,6 +117,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.WireguardPort = &p
|
ic.WireguardPort = &p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(networkMonitorFlag).Changed {
|
||||||
|
ic.NetworkMonitor = &networkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
ic.PreSharedKey = &preSharedKey
|
ic.PreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
@@ -147,7 +152,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
|
||||||
|
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||||
|
return connectClient.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
@@ -226,6 +233,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.WireguardPort = &wp
|
loginRequest.WireguardPort = &wp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(networkMonitorFlag).Changed {
|
||||||
|
loginRequest.NetworkMonitor = &networkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
|
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
|
|||||||
@@ -87,12 +87,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertRoutingRule inserts an iptable rule
|
// insertRoutingRule inserts an iptables rule
|
||||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||||
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
|
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||||
existingRule, found := i.rules[ruleKey]
|
existingRule, found := i.rules[ruleKey]
|
||||||
if found {
|
if found {
|
||||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||||
@@ -326,9 +326,9 @@ func (i *routerManager) createChain(table, newChain string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// genRuleSpec generates rule specification with comment identifier
|
// genRuleSpec generates rule specification
|
||||||
func genRuleSpec(jump, id, source, destination string) []string {
|
func genRuleSpec(jump, source, destination string) []string {
|
||||||
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getIptablesRuleType(table string) string {
|
func getIptablesRuleType(table string) string {
|
||||||
|
|||||||
@@ -51,14 +51,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
Destination: "100.100.100.0/24",
|
Destination: "100.100.100.0/24",
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
|
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
|
||||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
|
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
@@ -92,7 +90,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
require.NoError(t, err, "forwarding pair should be inserted")
|
||||||
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||||
@@ -103,7 +101,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||||
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||||
@@ -114,7 +112,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
@@ -130,7 +128,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
@@ -167,25 +165,25 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|||||||
@@ -64,15 +64,18 @@ func manageFirewallRule(ruleName string, action action, extraArgs ...string) err
|
|||||||
if action == addRule {
|
if action == addRule {
|
||||||
args = append(args, extraArgs...)
|
args = append(args, extraArgs...)
|
||||||
}
|
}
|
||||||
|
netshCmd := GetSystem32Command("netsh")
|
||||||
cmd := exec.Command("netsh", args...)
|
cmd := exec.Command(netshCmd, args...)
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
return cmd.Run()
|
return cmd.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
func isWindowsFirewallReachable() bool {
|
func isWindowsFirewallReachable() bool {
|
||||||
args := []string{"advfirewall", "show", "allprofiles", "state"}
|
args := []string{"advfirewall", "show", "allprofiles", "state"}
|
||||||
cmd := exec.Command("netsh", args...)
|
|
||||||
|
netshCmd := GetSystem32Command("netsh")
|
||||||
|
|
||||||
|
cmd := exec.Command(netshCmd, args...)
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
|
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
@@ -87,8 +90,23 @@ func isWindowsFirewallReachable() bool {
|
|||||||
func isFirewallRuleActive(ruleName string) bool {
|
func isFirewallRuleActive(ruleName string) bool {
|
||||||
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
||||||
|
|
||||||
cmd := exec.Command("netsh", args...)
|
netshCmd := GetSystem32Command("netsh")
|
||||||
|
|
||||||
|
cmd := exec.Command(netshCmd, args...)
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
|
||||||
|
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
|
||||||
|
func GetSystem32Command(command string) string {
|
||||||
|
_, err := exec.LookPath(command)
|
||||||
|
if err == nil {
|
||||||
|
return command
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
|
||||||
|
|
||||||
|
return "C:\\windows\\system32\\" + command + ".exe"
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -48,6 +50,7 @@ type ConfigInput struct {
|
|||||||
RosenpassPermissive *bool
|
RosenpassPermissive *bool
|
||||||
InterfaceName *string
|
InterfaceName *string
|
||||||
WireguardPort *int
|
WireguardPort *int
|
||||||
|
NetworkMonitor *bool
|
||||||
DisableAutoConnect *bool
|
DisableAutoConnect *bool
|
||||||
ExtraIFaceBlackList []string
|
ExtraIFaceBlackList []string
|
||||||
}
|
}
|
||||||
@@ -61,6 +64,7 @@ type Config struct {
|
|||||||
AdminURL *url.URL
|
AdminURL *url.URL
|
||||||
WgIface string
|
WgIface string
|
||||||
WgPort int
|
WgPort int
|
||||||
|
NetworkMonitor bool
|
||||||
IFaceBlackList []string
|
IFaceBlackList []string
|
||||||
DisableIPv6Discovery bool
|
DisableIPv6Discovery bool
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
@@ -100,6 +104,14 @@ func ReadConfig(configPath string) (*Config, error) {
|
|||||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// initialize through apply() without changes
|
||||||
|
if changed, err := config.apply(ConfigInput{}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if changed {
|
||||||
|
if err = WriteOutConfig(configPath, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
@@ -152,79 +164,15 @@ func WriteOutConfig(path string, config *Config) error {
|
|||||||
|
|
||||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||||
func createNewConfig(input ConfigInput) (*Config, error) {
|
func createNewConfig(input ConfigInput) (*Config, error) {
|
||||||
wgKey := generateKey()
|
|
||||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
SSHKey: string(pem),
|
// defaults to false only for new (post 0.26) configurations
|
||||||
PrivateKey: wgKey,
|
ServerSSHAllowed: util.False(),
|
||||||
IFaceBlackList: []string{},
|
|
||||||
DisableIPv6Discovery: false,
|
|
||||||
NATExternalIPs: input.NATExternalIPs,
|
|
||||||
CustomDNSAddress: string(input.CustomDNSAddress),
|
|
||||||
ServerSSHAllowed: util.False(),
|
|
||||||
DisableAutoConnect: false,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
if _, err := config.apply(input); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.ManagementURL = defaultManagementURL
|
|
||||||
if input.ManagementURL != "" {
|
|
||||||
URL, err := parseURL("Management URL", input.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config.ManagementURL = URL
|
|
||||||
}
|
|
||||||
|
|
||||||
config.WgPort = iface.DefaultWgPort
|
|
||||||
if input.WireguardPort != nil {
|
|
||||||
config.WgPort = *input.WireguardPort
|
|
||||||
}
|
|
||||||
|
|
||||||
config.WgIface = iface.WgInterfaceDefault
|
|
||||||
if input.InterfaceName != nil {
|
|
||||||
config.WgIface = *input.InterfaceName
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.PreSharedKey != nil {
|
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.RosenpassEnabled != nil {
|
|
||||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.RosenpassPermissive != nil {
|
|
||||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ServerSSHAllowed != nil {
|
|
||||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.AdminURL = defaultAdminURL
|
|
||||||
if input.AdminURL != "" {
|
|
||||||
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config.AdminURL = newURL
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:gocritic
|
|
||||||
config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...)
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,104 +183,12 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
refresh := false
|
updated, err := config.apply(input)
|
||||||
|
if err != nil {
|
||||||
if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
|
return nil, err
|
||||||
log.Infof("new Management URL provided, updated to %s (old value %s)",
|
|
||||||
input.ManagementURL, config.ManagementURL)
|
|
||||||
newURL, err := parseURL("Management URL", input.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config.ManagementURL = newURL
|
|
||||||
refresh = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
|
if updated {
|
||||||
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
|
|
||||||
input.AdminURL, config.AdminURL)
|
|
||||||
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config.AdminURL = newURL
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
|
||||||
log.Infof("new pre-shared key provided, replacing old key")
|
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.SSHKey == "" {
|
|
||||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config.SSHKey = string(pem)
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.WgPort == 0 {
|
|
||||||
config.WgPort = iface.DefaultWgPort
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.WireguardPort != nil {
|
|
||||||
config.WgPort = *input.WireguardPort
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.InterfaceName != nil {
|
|
||||||
config.WgIface = *input.InterfaceName
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
|
|
||||||
config.NATExternalIPs = input.NATExternalIPs
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.CustomDNSAddress != nil {
|
|
||||||
config.CustomDNSAddress = string(input.CustomDNSAddress)
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.RosenpassEnabled != nil {
|
|
||||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.RosenpassPermissive != nil {
|
|
||||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.DisableAutoConnect != nil {
|
|
||||||
config.DisableAutoConnect = *input.DisableAutoConnect
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ServerSSHAllowed != nil {
|
|
||||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ServerSSHAllowed == nil {
|
|
||||||
config.ServerSSHAllowed = util.True()
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(input.ExtraIFaceBlackList) > 0 {
|
|
||||||
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
|
|
||||||
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refresh {
|
|
||||||
// since we have new management URL, we need to update config file
|
|
||||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -341,6 +197,169 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||||
|
if config.ManagementURL == nil {
|
||||||
|
log.Infof("using default Management URL %s", DefaultManagementURL)
|
||||||
|
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
|
||||||
|
log.Infof("new Management URL provided, updated to %#v (old value %#v)",
|
||||||
|
input.ManagementURL, config.ManagementURL.String())
|
||||||
|
URL, err := parseURL("Management URL", input.ManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
config.ManagementURL = URL
|
||||||
|
updated = true
|
||||||
|
} else if config.ManagementURL == nil {
|
||||||
|
log.Infof("using default Management URL %s", DefaultManagementURL)
|
||||||
|
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AdminURL == nil {
|
||||||
|
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
||||||
|
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
|
||||||
|
log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
|
||||||
|
input.AdminURL, config.AdminURL.String())
|
||||||
|
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||||
|
if err != nil {
|
||||||
|
return updated, err
|
||||||
|
}
|
||||||
|
config.AdminURL = newURL
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.PrivateKey == "" {
|
||||||
|
log.Infof("generated new Wireguard key")
|
||||||
|
config.PrivateKey = generateKey()
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.SSHKey == "" {
|
||||||
|
log.Infof("generated new SSH key")
|
||||||
|
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
config.SSHKey = string(pem)
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
|
||||||
|
log.Infof("updating Wireguard port %d (old value %d)",
|
||||||
|
*input.WireguardPort, config.WgPort)
|
||||||
|
config.WgPort = *input.WireguardPort
|
||||||
|
updated = true
|
||||||
|
} else if config.WgPort == 0 {
|
||||||
|
config.WgPort = iface.DefaultWgPort
|
||||||
|
log.Infof("using default Wireguard port %d", config.WgPort)
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||||
|
log.Infof("updating Wireguard interface %#v (old value %#v)",
|
||||||
|
*input.InterfaceName, config.WgIface)
|
||||||
|
config.WgIface = *input.InterfaceName
|
||||||
|
updated = true
|
||||||
|
} else if config.WgIface == "" {
|
||||||
|
config.WgIface = iface.WgInterfaceDefault
|
||||||
|
log.Infof("using default Wireguard interface %s", config.WgIface)
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
|
||||||
|
log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
|
||||||
|
strings.Join(input.NATExternalIPs, " "),
|
||||||
|
strings.Join(config.NATExternalIPs, " "))
|
||||||
|
config.NATExternalIPs = input.NATExternalIPs
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
|
||||||
|
log.Infof("new pre-shared key provided, replacing old key")
|
||||||
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
|
||||||
|
log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
|
||||||
|
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive {
|
||||||
|
log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive)
|
||||||
|
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor {
|
||||||
|
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
||||||
|
config.NetworkMonitor = *input.NetworkMonitor
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
|
||||||
|
log.Infof("updating custom DNS address %#v (old value %#v)",
|
||||||
|
string(input.CustomDNSAddress), config.CustomDNSAddress)
|
||||||
|
config.CustomDNSAddress = string(input.CustomDNSAddress)
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.IFaceBlackList) == 0 {
|
||||||
|
log.Infof("filling in interface blacklist with defaults: [ %s ]",
|
||||||
|
strings.Join(defaultInterfaceBlacklist, " "))
|
||||||
|
config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...)
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(input.ExtraIFaceBlackList) > 0 {
|
||||||
|
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
|
||||||
|
log.Infof("adding new entry to interface blacklist: %s", iFace)
|
||||||
|
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect {
|
||||||
|
if *input.DisableAutoConnect {
|
||||||
|
log.Infof("turning off automatic connection on startup")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling automatic connection on startup")
|
||||||
|
}
|
||||||
|
config.DisableAutoConnect = *input.DisableAutoConnect
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
|
||||||
|
if *input.ServerSSHAllowed {
|
||||||
|
log.Infof("enabling SSH server")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH server")
|
||||||
|
}
|
||||||
|
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||||
|
updated = true
|
||||||
|
} else if config.ServerSSHAllowed == nil {
|
||||||
|
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||||
|
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||||
|
config.ServerSSHAllowed = util.True()
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
// parseURL parses and validates a service URL
|
// parseURL parses and validates a service URL
|
||||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
@@ -29,30 +30,45 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
type ConnectClient struct {
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
ctx context.Context
|
||||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
|
config *Config
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
engine *Engine
|
||||||
|
engineMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunClientWithProbes runs the client's main logic with probes attached
|
func NewConnectClient(
|
||||||
func RunClientWithProbes(
|
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *Config,
|
config *Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
|
||||||
|
) *ConnectClient {
|
||||||
|
return &ConnectClient{
|
||||||
|
ctx: ctx,
|
||||||
|
config: config,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
engineMutex: sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run with main logic.
|
||||||
|
func (c *ConnectClient) Run() error {
|
||||||
|
return c.run(MobileDependency{}, nil, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunWithProbes runs the client's main logic with probes attached
|
||||||
|
func (c *ConnectClient) RunWithProbes(
|
||||||
mgmProbe *Probe,
|
mgmProbe *Probe,
|
||||||
signalProbe *Probe,
|
signalProbe *Probe,
|
||||||
relayProbe *Probe,
|
relayProbe *Probe,
|
||||||
wgProbe *Probe,
|
wgProbe *Probe,
|
||||||
engineChan chan<- *Engine,
|
|
||||||
) error {
|
) error {
|
||||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
|
return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunClientMobile with main logic on mobile system
|
// RunOnAndroid with main logic on mobile system
|
||||||
func RunClientMobile(
|
func (c *ConnectClient) RunOnAndroid(
|
||||||
ctx context.Context,
|
|
||||||
config *Config,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
tunAdapter iface.TunAdapter,
|
tunAdapter iface.TunAdapter,
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
@@ -67,13 +83,10 @@ func RunClientMobile(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
}
|
}
|
||||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
|
return c.run(mobileDependency, nil, nil, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunClientiOS(
|
func (c *ConnectClient) RunOniOS(
|
||||||
ctx context.Context,
|
|
||||||
config *Config,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
fileDescriptor int32,
|
fileDescriptor int32,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsManager dns.IosDnsManager,
|
dnsManager dns.IosDnsManager,
|
||||||
@@ -83,19 +96,15 @@ func RunClientiOS(
|
|||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
}
|
}
|
||||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
|
return c.run(mobileDependency, nil, nil, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runClient(
|
func (c *ConnectClient) run(
|
||||||
ctx context.Context,
|
|
||||||
config *Config,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
mobileDependency MobileDependency,
|
mobileDependency MobileDependency,
|
||||||
mgmProbe *Probe,
|
mgmProbe *Probe,
|
||||||
signalProbe *Probe,
|
signalProbe *Probe,
|
||||||
relayProbe *Probe,
|
relayProbe *Probe,
|
||||||
wgProbe *Probe,
|
wgProbe *Probe,
|
||||||
engineChan chan<- *Engine,
|
|
||||||
) error {
|
) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -107,7 +116,7 @@ func runClient(
|
|||||||
|
|
||||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||||
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
|
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
|
||||||
log.Errorf("checking unclean shutdown error: %s", err)
|
log.Errorf("checking unclean shutdown error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +130,7 @@ func runClient(
|
|||||||
Clock: backoff.SystemClock,
|
Clock: backoff.SystemClock,
|
||||||
}
|
}
|
||||||
|
|
||||||
state := CtxGetState(ctx)
|
state := CtxGetState(c.ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
s, err := state.Status()
|
s, err := state.Status()
|
||||||
if err != nil || s != StatusNeedsLogin {
|
if err != nil || s != StatusNeedsLogin {
|
||||||
@@ -130,49 +139,49 @@ func runClient(
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
wrapErr := state.Wrap
|
wrapErr := state.Wrap
|
||||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
var mgmTlsEnabled bool
|
||||||
if config.ManagementURL.Scheme == "https" {
|
if c.config.ManagementURL.Scheme == "https" {
|
||||||
mgmTlsEnabled = true
|
mgmTlsEnabled = true
|
||||||
}
|
}
|
||||||
|
|
||||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-c.ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
state.Set(StatusConnecting)
|
state.Set(StatusConnecting)
|
||||||
|
|
||||||
engineCtx, cancel := context.WithCancel(ctx)
|
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
statusRecorder.MarkManagementDisconnected(state.err)
|
c.statusRecorder.MarkManagementDisconnected(state.err)
|
||||||
statusRecorder.CleanLocalPeerState()
|
c.statusRecorder.CleanLocalPeerState()
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||||
}
|
}
|
||||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
|
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||||
mgmClient.SetConnStateListener(mgmNotifier)
|
mgmClient.SetConnStateListener(mgmNotifier)
|
||||||
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||||
defer func() {
|
defer func() {
|
||||||
err = mgmClient.Close()
|
err = mgmClient.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -190,7 +199,7 @@ func runClient(
|
|||||||
}
|
}
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
statusRecorder.MarkManagementConnected()
|
c.statusRecorder.MarkManagementConnected()
|
||||||
|
|
||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
@@ -199,18 +208,18 @@ func runClient(
|
|||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
|
|
||||||
statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|
||||||
signalURL := fmt.Sprintf("%s://%s",
|
signalURL := fmt.Sprintf("%s://%s",
|
||||||
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
|
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
|
||||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
||||||
)
|
)
|
||||||
|
|
||||||
statusRecorder.UpdateSignalAddress(signalURL)
|
c.statusRecorder.UpdateSignalAddress(signalURL)
|
||||||
|
|
||||||
statusRecorder.MarkSignalDisconnected(nil)
|
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||||
defer func() {
|
defer func() {
|
||||||
statusRecorder.MarkSignalDisconnected(state.err)
|
c.statusRecorder.MarkSignalDisconnected(state.err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||||
@@ -226,42 +235,38 @@ func runClient(
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
|
signalNotifier := statusRecorderToSignalConnStateNotifier(c.statusRecorder)
|
||||||
signalClient.SetConnStateListener(signalNotifier)
|
signalClient.SetConnStateListener(signalNotifier)
|
||||||
|
|
||||||
statusRecorder.MarkSignalConnected()
|
c.statusRecorder.MarkSignalConnected()
|
||||||
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
|
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
c.engineMutex.Lock()
|
||||||
err = engine.Start()
|
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||||
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
|
err = c.engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
if engineChan != nil {
|
|
||||||
engineChan <- engine
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
statusRecorder.ClientTeardown()
|
c.statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
backOff.Reset()
|
backOff.Reset()
|
||||||
|
|
||||||
if engineChan != nil {
|
err = c.engine.Stop()
|
||||||
engineChan <- nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.Stop()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed stopping engine %v", err)
|
log.Errorf("failed stopping engine %v", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@@ -276,7 +281,7 @@ func runClient(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
statusRecorder.ClientStart()
|
c.statusRecorder.ClientStart()
|
||||||
err = backoff.Retry(operation, backOff)
|
err = backoff.Retry(operation, backOff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||||
@@ -288,6 +293,14 @@ func runClient(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConnectClient) Engine() *Engine {
|
||||||
|
var e *Engine
|
||||||
|
c.engineMutex.Lock()
|
||||||
|
e = c.engine
|
||||||
|
c.engineMutex.Unlock()
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||||
engineConf := &EngineConfig{
|
engineConf := &EngineConfig{
|
||||||
@@ -297,6 +310,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: config.WgPort,
|
WgPort: config.WgPort,
|
||||||
|
NetworkMonitor: config.NetworkMonitor,
|
||||||
SSHKey: []byte(config.SSHKey),
|
SSHKey: []byte(config.SSHKey),
|
||||||
NATExternalIPs: config.NATExternalIPs,
|
NATExternalIPs: config.NATExternalIPs,
|
||||||
CustomDNSAddress: config.CustomDNSAddress,
|
CustomDNSAddress: config.CustomDNSAddress,
|
||||||
|
|||||||
@@ -47,24 +47,20 @@ func (f *fileConfigurator) supportCustomPort() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||||
backupFileExist := false
|
backupFileExist := f.isBackupFileExist()
|
||||||
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
|
||||||
if err == nil {
|
|
||||||
backupFileExist = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !config.RouteAll {
|
if !config.RouteAll {
|
||||||
if backupFileExist {
|
if backupFileExist {
|
||||||
err = f.restore()
|
f.repair.stopWatchFileChanges()
|
||||||
|
err := f.restore()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
|
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !backupFileExist {
|
if !backupFileExist {
|
||||||
err = f.backup()
|
err := f.backup()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
|
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -184,6 +180,11 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fileConfigurator) isBackupFileExist() bool {
|
||||||
|
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
func restoreResolvConfFile() error {
|
func restoreResolvConfFile() error {
|
||||||
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
|
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
|
||||||
|
|
||||||
|
|||||||
63
client/internal/dns/hosts_dns_holder.go
Normal file
63
client/internal/dns/hosts_dns_holder.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type hostsDNSHolder struct {
|
||||||
|
unprotectedDNSList map[string]struct{}
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHostsDNSHolder() *hostsDNSHolder {
|
||||||
|
return &hostsDNSHolder{
|
||||||
|
unprotectedDNSList: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hostsDNSHolder) set(list []string) {
|
||||||
|
h.mutex.Lock()
|
||||||
|
h.unprotectedDNSList = make(map[string]struct{})
|
||||||
|
for _, dns := range list {
|
||||||
|
dnsAddr, err := h.normalizeAddress(dns)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.unprotectedDNSList[dnsAddr] = struct{}{}
|
||||||
|
}
|
||||||
|
h.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hostsDNSHolder) get() map[string]struct{} {
|
||||||
|
h.mutex.RLock()
|
||||||
|
l := h.unprotectedDNSList
|
||||||
|
h.mutex.RUnlock()
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:unused
|
||||||
|
func (h *hostsDNSHolder) isContain(upstream string) bool {
|
||||||
|
h.mutex.RLock()
|
||||||
|
defer h.mutex.RUnlock()
|
||||||
|
|
||||||
|
_, ok := h.unprotectedDNSList[upstream]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
|
||||||
|
a, err := netip.ParseAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Is4() {
|
||||||
|
return fmt.Sprintf("%s:53", addr), nil
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("[%s]:53", addr), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -54,9 +55,8 @@ type DefaultServer struct {
|
|||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
hostsDnsList []string
|
hostsDNSHolder *hostsDNSHolder
|
||||||
hostsDnsListLock sync.Mutex
|
|
||||||
|
|
||||||
// make sense on mobile only
|
// make sense on mobile only
|
||||||
searchDomainNotifier *notifier
|
searchDomainNotifier *notifier
|
||||||
@@ -113,8 +113,8 @@ func NewDefaultServerPermanentUpstream(
|
|||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||||
|
ds.hostsDNSHolder.set(hostsDnsList)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.hostsDnsList = hostsDnsList
|
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
|
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
|
||||||
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
|
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
|
||||||
@@ -147,6 +147,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
},
|
},
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return defaultServer
|
return defaultServer
|
||||||
@@ -202,10 +203,8 @@ func (s *DefaultServer) Stop() {
|
|||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
s.hostsDnsListLock.Lock()
|
s.hostsDNSHolder.set(hostsDnsList)
|
||||||
defer s.hostsDnsListLock.Unlock()
|
|
||||||
|
|
||||||
s.hostsDnsList = hostsDnsList
|
|
||||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||||
if ok {
|
if ok {
|
||||||
log.Debugf("on new host DNS config but skip to apply it")
|
log.Debugf("on new host DNS config but skip to apply it")
|
||||||
@@ -374,6 +373,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
s.wgInterface.Address().IP,
|
s.wgInterface.Address().IP,
|
||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
|
s.hostsDNSHolder,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
@@ -452,9 +452,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
_, found := muxUpdateMap[key]
|
_, found := muxUpdateMap[key]
|
||||||
if !found {
|
if !found {
|
||||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||||
s.hostsDnsListLock.Lock()
|
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
s.hostsDnsListLock.Unlock()
|
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
} else {
|
} else {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
@@ -512,6 +510,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
|
s.service.DeregisterMux(nbdns.RootZone)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
@@ -521,10 +520,15 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||||
|
s.addHostRootZone()
|
||||||
|
}
|
||||||
|
|
||||||
s.updateNSState(nsGroup, err, false)
|
s.updateNSState(nsGroup, err, false)
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -545,6 +549,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
|
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
@@ -562,25 +567,16 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
s.wgInterface.Address().IP,
|
s.wgInterface.Address().IP,
|
||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
|
s.hostsDNSHolder,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
|
||||||
for n, ua := range s.hostsDnsList {
|
|
||||||
a, err := netip.ParseAddr(ua)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ipString := ua
|
handler.upstreamServers = make([]string, 0)
|
||||||
if !a.Is4() {
|
for k := range s.hostsDNSHolder.get() {
|
||||||
ipString = fmt.Sprintf("[%s]", ua)
|
handler.upstreamServers = append(handler.upstreamServers, k)
|
||||||
}
|
|
||||||
|
|
||||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
|
||||||
}
|
}
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
|
|||||||
84
client/internal/dns/upstream_android.go
Normal file
84
client/internal/dns/upstream_android.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type upstreamResolver struct {
|
||||||
|
*upstreamResolverBase
|
||||||
|
hostsDNSHolder *hostsDNSHolder
|
||||||
|
}
|
||||||
|
|
||||||
|
// newUpstreamResolver in Android we need to distinguish the DNS servers to available through VPN or outside of VPN
|
||||||
|
// In case if the assigned DNS address is available only in the protected network then the resolver will time out at the
|
||||||
|
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
|
||||||
|
func newUpstreamResolver(
|
||||||
|
ctx context.Context,
|
||||||
|
_ string,
|
||||||
|
_ net.IP,
|
||||||
|
_ *net.IPNet,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
|
) (*upstreamResolver, error) {
|
||||||
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||||
|
c := &upstreamResolver{
|
||||||
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
|
hostsDNSHolder: hostsDNSHolder,
|
||||||
|
}
|
||||||
|
upstreamResolverBase.upstreamClient = c
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchange in case of Android if the upstream is a local resolver then we do not need to mark the socket as protected.
|
||||||
|
// In other case the DNS resolvation goes through the VPN, so we need to force to use the
|
||||||
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
|
if u.isLocalResolver(upstream) {
|
||||||
|
return u.exchangeWithoutVPN(ctx, upstream, r)
|
||||||
|
} else {
|
||||||
|
return u.exchangeWithinVPN(ctx, upstream, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
|
upstreamExchangeClient := &dns.Client{}
|
||||||
|
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||||
|
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
|
timeout := upstreamTimeout
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
timeout = time.Until(deadline)
|
||||||
|
}
|
||||||
|
dialTimeout := timeout
|
||||||
|
|
||||||
|
nbDialer := nbnet.NewDialer()
|
||||||
|
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
return nbDialer.Control(network, address, c)
|
||||||
|
},
|
||||||
|
Timeout: dialTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamExchangeClient := &dns.Client{
|
||||||
|
Dialer: dialer,
|
||||||
|
}
|
||||||
|
|
||||||
|
return upstreamExchangeClient.Exchange(r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||||
|
if u.hostsDNSHolder.isContain(upstream) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !ios
|
//go:build !android && !ios
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolverNonIOS struct {
|
type upstreamResolver struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,16 +22,17 @@ func newUpstreamResolver(
|
|||||||
_ net.IP,
|
_ net.IP,
|
||||||
_ *net.IPNet,
|
_ *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) (*upstreamResolverNonIOS, error) {
|
_ *hostsDNSHolder,
|
||||||
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||||
nonIOS := &upstreamResolverNonIOS{
|
nonIOS := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
}
|
}
|
||||||
upstreamResolverBase.upstreamClient = nonIOS
|
upstreamResolverBase.upstreamClient = nonIOS
|
||||||
return nonIOS, nil
|
return nonIOS, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverNonIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
upstreamExchangeClient := &dns.Client{}
|
upstreamExchangeClient := &dns.Client{}
|
||||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||||
}
|
}
|
||||||
@@ -28,6 +28,7 @@ func newUpstreamResolver(
|
|||||||
ip net.IP,
|
ip net.IP,
|
||||||
net *net.IPNet,
|
net *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
_ *hostsDNSHolder,
|
||||||
) (*upstreamResolverIOS, error) {
|
) (*upstreamResolverIOS, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil)
|
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
|
||||||
resolver.upstreamServers = testCase.InputServers
|
resolver.upstreamServers = testCase.InputServers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
@@ -60,6 +62,9 @@ type EngineConfig struct {
|
|||||||
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
|
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
|
||||||
WgPrivateKey wgtypes.Key
|
WgPrivateKey wgtypes.Key
|
||||||
|
|
||||||
|
// NetworkMonitor is a flag to enable network monitoring
|
||||||
|
NetworkMonitor bool
|
||||||
|
|
||||||
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
|
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
|
||||||
IFaceBlackList []string
|
IFaceBlackList []string
|
||||||
DisableIPv6Discovery bool
|
DisableIPv6Discovery bool
|
||||||
@@ -112,12 +117,14 @@ type Engine struct {
|
|||||||
TURNs []*stun.URI
|
TURNs []*stun.URI
|
||||||
|
|
||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
clientRoutes map[string][]*route.Route
|
clientRoutes route.HAMap
|
||||||
|
|
||||||
|
clientCtx context.Context
|
||||||
|
clientCancel context.CancelFunc
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
|
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
wgProxyFactory *wgproxy.Factory
|
wgProxyFactory *wgproxy.Factory
|
||||||
|
|
||||||
@@ -126,6 +133,8 @@ type Engine struct {
|
|||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
|
|
||||||
|
networkWatcher *networkmonitor.NetworkWatcher
|
||||||
|
|
||||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||||
sshServer nbssh.Server
|
sshServer nbssh.Server
|
||||||
|
|
||||||
@@ -151,8 +160,8 @@ type Peer struct {
|
|||||||
|
|
||||||
// NewEngine creates a new Connection Engine
|
// NewEngine creates a new Connection Engine
|
||||||
func NewEngine(
|
func NewEngine(
|
||||||
ctx context.Context,
|
clientCtx context.Context,
|
||||||
cancel context.CancelFunc,
|
clientCancel context.CancelFunc,
|
||||||
signalClient signal.Client,
|
signalClient signal.Client,
|
||||||
mgmClient mgm.Client,
|
mgmClient mgm.Client,
|
||||||
config *EngineConfig,
|
config *EngineConfig,
|
||||||
@@ -160,8 +169,8 @@ func NewEngine(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
return NewEngineWithProbes(
|
return NewEngineWithProbes(
|
||||||
ctx,
|
clientCtx,
|
||||||
cancel,
|
clientCancel,
|
||||||
signalClient,
|
signalClient,
|
||||||
mgmClient,
|
mgmClient,
|
||||||
config,
|
config,
|
||||||
@@ -176,8 +185,8 @@ func NewEngine(
|
|||||||
|
|
||||||
// NewEngineWithProbes creates a new Connection Engine with probes attached
|
// NewEngineWithProbes creates a new Connection Engine with probes attached
|
||||||
func NewEngineWithProbes(
|
func NewEngineWithProbes(
|
||||||
ctx context.Context,
|
clientCtx context.Context,
|
||||||
cancel context.CancelFunc,
|
clientCancel context.CancelFunc,
|
||||||
signalClient signal.Client,
|
signalClient signal.Client,
|
||||||
mgmClient mgm.Client,
|
mgmClient mgm.Client,
|
||||||
config *EngineConfig,
|
config *EngineConfig,
|
||||||
@@ -188,9 +197,10 @@ func NewEngineWithProbes(
|
|||||||
relayProbe *Probe,
|
relayProbe *Probe,
|
||||||
wgProbe *Probe,
|
wgProbe *Probe,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
|
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
clientCtx: clientCtx,
|
||||||
cancel: cancel,
|
clientCancel: clientCancel,
|
||||||
signal: signalClient,
|
signal: signalClient,
|
||||||
mgmClient: mgmClient,
|
mgmClient: mgmClient,
|
||||||
peerConns: make(map[string]*peer.Conn),
|
peerConns: make(map[string]*peer.Conn),
|
||||||
@@ -202,7 +212,7 @@ func NewEngineWithProbes(
|
|||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
sshServerFunc: nbssh.DefaultSSHServer,
|
sshServerFunc: nbssh.DefaultSSHServer,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
networkWatcher: networkmonitor.New(),
|
||||||
mgmProbe: mgmProbe,
|
mgmProbe: mgmProbe,
|
||||||
signalProbe: signalProbe,
|
signalProbe: signalProbe,
|
||||||
relayProbe: relayProbe,
|
relayProbe: relayProbe,
|
||||||
@@ -214,6 +224,13 @@ func (e *Engine) Stop() error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if e.cancel != nil {
|
||||||
|
e.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopping network monitor first to avoid starting the engine again
|
||||||
|
e.networkWatcher.Stop()
|
||||||
|
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -222,7 +239,7 @@ func (e *Engine) Stop() error {
|
|||||||
e.clientRoutes = nil
|
e.clientRoutes = nil
|
||||||
|
|
||||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||||
// Removing peers happens in the conn.CLose() asynchronously
|
// Removing peers happens in the conn.Close() asynchronously
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
@@ -237,6 +254,13 @@ func (e *Engine) Start() error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if e.cancel != nil {
|
||||||
|
e.cancel()
|
||||||
|
}
|
||||||
|
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||||
|
|
||||||
|
e.wgProxyFactory = wgproxy.NewFactory(e.config.WgPort)
|
||||||
|
|
||||||
wgIface, err := e.newWgIface()
|
wgIface, err := e.newWgIface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||||
@@ -320,6 +344,21 @@ func (e *Engine) Start() error {
|
|||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
e.receiveProbeEvents()
|
e.receiveProbeEvents()
|
||||||
|
|
||||||
|
if e.config.NetworkMonitor {
|
||||||
|
// starting network monitor at the very last to avoid disruptions
|
||||||
|
go e.networkWatcher.Start(e.ctx, func() {
|
||||||
|
log.Infof("Network monitor detected network change, restarting engine")
|
||||||
|
if err := e.Stop(); err != nil {
|
||||||
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
|
}
|
||||||
|
if err := e.Start(); err != nil {
|
||||||
|
log.Errorf("Failed to start engine: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -588,12 +627,12 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
func (e *Engine) receiveManagementEvents() {
|
func (e *Engine) receiveManagementEvents() {
|
||||||
go func() {
|
go func() {
|
||||||
err := e.mgmClient.Sync(e.handleSync)
|
err := e.mgmClient.Sync(e.ctx, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
e.cancel()
|
e.clientCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("stopped receiving updates from Management Service")
|
log.Debugf("stopped receiving updates from Management Service")
|
||||||
@@ -736,9 +775,9 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
for _, protoRoute := range protoRoutes {
|
for _, protoRoute := range protoRoutes {
|
||||||
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
||||||
convertedRoute := &route.Route{
|
convertedRoute := &route.Route{
|
||||||
ID: protoRoute.ID,
|
ID: route.ID(protoRoute.ID),
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
NetID: protoRoute.NetID,
|
NetID: route.NetID(protoRoute.NetID),
|
||||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||||
Peer: protoRoute.Peer,
|
Peer: protoRoute.Peer,
|
||||||
Metric: int(protoRoute.Metric),
|
Metric: int(protoRoute.Metric),
|
||||||
@@ -869,11 +908,12 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
|||||||
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||||
e.syncMsgMux.Unlock()
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
err := conn.Open()
|
err := conn.Open(e.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("connection to peer %s failed: %v", peerKey, err)
|
log.Debugf("connection to peer %s failed: %v", peerKey, err)
|
||||||
switch err.(type) {
|
var connectionClosedError *peer.ConnectionClosedError
|
||||||
case *peer.ConnectionClosedError:
|
switch {
|
||||||
|
case errors.As(err, &connectionClosedError):
|
||||||
// conn has been forced to close, so we exit the loop
|
// conn has been forced to close, so we exit the loop
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
@@ -984,7 +1024,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
func (e *Engine) receiveSignalEvents() {
|
func (e *Engine) receiveSignalEvents() {
|
||||||
go func() {
|
go func() {
|
||||||
// connect to a stream of messages coming from the signal server
|
// connect to a stream of messages coming from the signal server
|
||||||
err := e.signal.Receive(func(msg *sProto.Message) error {
|
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -1058,7 +1098,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
// happens if signal is unavailable for a long time.
|
// happens if signal is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
e.cancel()
|
e.clientCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -1119,13 +1159,16 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
if err := e.wgProxyFactory.Free(); err != nil {
|
if e.wgProxyFactory != nil {
|
||||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
if err := e.wgProxyFactory.Free(); err != nil {
|
||||||
|
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
e.dnsServer.Stop()
|
e.dnsServer.Stop()
|
||||||
|
e.dnsServer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
@@ -1238,18 +1281,15 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the current routes from the route map
|
// GetClientRoutes returns the current routes from the route map
|
||||||
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
|
func (e *Engine) GetClientRoutes() route.HAMap {
|
||||||
return e.clientRoutes
|
return e.clientRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||||
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
|
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
routes := make(map[string][]*route.Route, len(e.clientRoutes))
|
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
||||||
for id, v := range e.clientRoutes {
|
for id, v := range e.clientRoutes {
|
||||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
routes[id.NetID()] = v
|
||||||
id = id[:i]
|
|
||||||
}
|
|
||||||
routes[id] = v
|
|
||||||
}
|
}
|
||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -392,7 +392,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
// feed updates to Engine via mocked Management client
|
// feed updates to Engine via mocked Management client
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
defer close(updates)
|
defer close(updates)
|
||||||
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
for msg := range updates {
|
for msg := range updates {
|
||||||
err := msgHandler(msg)
|
err := msgHandler(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -578,7 +578,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
}{}
|
}{}
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
input.inputSerial = updateSerial
|
input.inputSerial = updateSerial
|
||||||
input.inputRoutes = newRoutes
|
input.inputRoutes = newRoutes
|
||||||
return nil, nil, testCase.inputErr
|
return nil, nil, testCase.inputErr
|
||||||
@@ -743,7 +743,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
package internal
|
|
||||||
@@ -68,7 +68,7 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
|||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||||
if isRegistrationNeeded(err) {
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
log.Debugf("peer registration required")
|
log.Debugf("peer registration required")
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||||
return err
|
return err
|
||||||
|
|||||||
15
client/internal/networkmonitor/monitor.go
Normal file
15
client/internal/networkmonitor/monitor.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetworkWatcher watches for changes in network configuration.
|
||||||
|
type NetworkWatcher struct {
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new network monitor.
|
||||||
|
func New() *NetworkWatcher {
|
||||||
|
return &NetworkWatcher{}
|
||||||
|
}
|
||||||
133
client/internal/networkmonitor/monitor_bsd.go
Normal file
133
client/internal/networkmonitor/monitor_bsd.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||||
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := unix.Close(fd); err != nil {
|
||||||
|
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
buf := make([]byte, 2048)
|
||||||
|
n, err := unix.Read(fd, buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n < unix.SizeofRtMsghdr {
|
||||||
|
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
|
|
||||||
|
// handle interface state changes
|
||||||
|
case unix.RTM_IFINFO:
|
||||||
|
ifinfo, err := parseInterfaceMessage(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Network monitor: error parsing interface message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if msg.Flags&unix.IFF_UP != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
|
||||||
|
callback()
|
||||||
|
|
||||||
|
// handle route changes
|
||||||
|
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||||
|
route, err := parseRouteMessage(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Network monitor: error parsing routing message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !route.Dst.Addr().IsUnspecified() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := "<nil>"
|
||||||
|
if route.Interface != nil {
|
||||||
|
intf = route.Interface.Name
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case unix.RTM_ADD:
|
||||||
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||||
|
callback()
|
||||||
|
case unix.RTM_DELETE:
|
||||||
|
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
|
||||||
|
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||||
|
callback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := msgs[0].(*route.InterfaceMessage)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := msgs[0].(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return routemanager.MsgToRoute(msg)
|
||||||
|
}
|
||||||
82
client/internal/networkmonitor/monitor_generic.go
Normal file
82
client/internal/networkmonitor/monitor_generic.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime/debug"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start begins watching for network changes and calls the callback function and stops when a change is detected.
|
||||||
|
func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
|
||||||
|
if nw.cancel != nil {
|
||||||
|
log.Warn("Network monitor: already running, stopping previous watcher")
|
||||||
|
nw.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
log.Info("Network monitor: not starting, context is already cancelled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, nw.cancel = context.WithCancel(ctx)
|
||||||
|
defer nw.Stop()
|
||||||
|
|
||||||
|
var nexthop4, nexthop6 netip.Addr
|
||||||
|
var intf4, intf6 *net.Interface
|
||||||
|
|
||||||
|
operation := func() error {
|
||||||
|
var errv4, errv6 error
|
||||||
|
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
|
||||||
|
if errv4 != nil && errv6 != nil {
|
||||||
|
return errors.New("failed to get default next hops")
|
||||||
|
}
|
||||||
|
|
||||||
|
if errv4 == nil {
|
||||||
|
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
|
||||||
|
}
|
||||||
|
if errv6 == nil {
|
||||||
|
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// continue if either route was found
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||||
|
|
||||||
|
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||||
|
log.Errorf("Network monitor: failed to get default next hops: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// recover in case sys ops panic
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
log.Errorf("Network monitor: failed to start: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the network monitor.
|
||||||
|
func (nw *NetworkWatcher) Stop() {
|
||||||
|
if nw.cancel != nil {
|
||||||
|
nw.cancel()
|
||||||
|
nw.cancel = nil
|
||||||
|
log.Info("Network monitor: stopped")
|
||||||
|
}
|
||||||
|
}
|
||||||
81
client/internal/networkmonitor/monitor_linux.go
Normal file
81
client/internal/networkmonitor/monitor_linux.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||||
|
if intfv4 == nil && intfv6 == nil {
|
||||||
|
return errors.New("no interfaces available")
|
||||||
|
}
|
||||||
|
|
||||||
|
linkChan := make(chan netlink.LinkUpdate)
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||||
|
return fmt.Errorf("subscribe to link updates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
routeChan := make(chan netlink.RouteUpdate)
|
||||||
|
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
|
||||||
|
return fmt.Errorf("subscribe to route updates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Network monitor: started")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
|
||||||
|
// handle interface state changes
|
||||||
|
case update := <-linkChan:
|
||||||
|
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch update.Header.Type {
|
||||||
|
case syscall.RTM_DELLINK:
|
||||||
|
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
case syscall.RTM_NEWLINK:
|
||||||
|
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
|
||||||
|
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle route changes
|
||||||
|
case route := <-routeChan:
|
||||||
|
// default route and main table
|
||||||
|
if route.Dst != nil || route.Table != syscall.RT_TABLE_MAIN {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch route.Type {
|
||||||
|
// triggered on added/replaced routes
|
||||||
|
case syscall.RTM_NEWROUTE:
|
||||||
|
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
case syscall.RTM_DELROUTE:
|
||||||
|
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
|
||||||
|
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
11
client/internal/networkmonitor/monitor_mobile.go
Normal file
11
client/internal/networkmonitor/monitor_mobile.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
func (nw *NetworkWatcher) Start(context.Context, func()) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nw *NetworkWatcher) Stop() {
|
||||||
|
}
|
||||||
215
client/internal/networkmonitor/monitor_windows.go
Normal file
215
client/internal/networkmonitor/monitor_windows.go
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
unreachable = 0
|
||||||
|
incomplete = 1
|
||||||
|
probe = 2
|
||||||
|
delay = 3
|
||||||
|
stale = 4
|
||||||
|
reachable = 5
|
||||||
|
permanent = 6
|
||||||
|
tbd = 7
|
||||||
|
)
|
||||||
|
|
||||||
|
const interval = 10 * time.Second
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||||
|
var neighborv4, neighborv6 *routemanager.Neighbor
|
||||||
|
{
|
||||||
|
initialNeighbors, err := getNeighbors()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get neighbors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n, ok := initialNeighbors[nexthopv4]; ok {
|
||||||
|
neighborv4 = &n
|
||||||
|
}
|
||||||
|
if n, ok := initialNeighbors[nexthopv6]; ok {
|
||||||
|
neighborv6 = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func changed(
|
||||||
|
nexthopv4 netip.Addr,
|
||||||
|
intfv4 *net.Interface,
|
||||||
|
neighborv4 *routemanager.Neighbor,
|
||||||
|
nexthopv6 netip.Addr,
|
||||||
|
intfv6 *net.Interface,
|
||||||
|
neighborv6 *routemanager.Neighbor,
|
||||||
|
) bool {
|
||||||
|
neighbors, err := getNeighbors()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("network monitor: error fetching current neighbors: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
routes, err := getRoutes()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("network monitor: error fetching current routes: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||||
|
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
|
||||||
|
if !nexthop.IsValid() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var unspec netip.Prefix
|
||||||
|
if nexthop.Is6() {
|
||||||
|
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
} else {
|
||||||
|
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r, ok := routes[unspec]; ok {
|
||||||
|
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
|
||||||
|
intf := "<nil>"
|
||||||
|
if r.Interface != nil {
|
||||||
|
intf = r.Interface.Name
|
||||||
|
}
|
||||||
|
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("network monitor: default route is gone")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
|
||||||
|
if neighbor == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||||
|
if n, ok := neighbors[nexthop]; ok {
|
||||||
|
if n.State != reachable && n.State != permanent {
|
||||||
|
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||||
|
return true
|
||||||
|
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
|
||||||
|
log.Infof(
|
||||||
|
"network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s",
|
||||||
|
neighbor.IPAddress,
|
||||||
|
neighbor.LinkLayerAddress,
|
||||||
|
neighbor.InterfaceAlias,
|
||||||
|
neighbor.InterfaceIndex,
|
||||||
|
n.InterfaceAlias,
|
||||||
|
n.InterfaceIndex,
|
||||||
|
stateFromInt(n.State),
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
|
||||||
|
entries, err := routemanager.GetNeighbors()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get neighbors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
neighbours[entry.IPAddress] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
return neighbours, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
|
||||||
|
entries, err := routemanager.GetRoutes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := make(map[netip.Prefix]routemanager.Route, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
routes[entry.Destination] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stateFromInt(state uint8) string {
|
||||||
|
switch state {
|
||||||
|
case unreachable:
|
||||||
|
return "unreachable"
|
||||||
|
case incomplete:
|
||||||
|
return "incomplete"
|
||||||
|
case probe:
|
||||||
|
return "probe"
|
||||||
|
case delay:
|
||||||
|
return "delay"
|
||||||
|
case stale:
|
||||||
|
return "stale"
|
||||||
|
case reachable:
|
||||||
|
return "reachable"
|
||||||
|
case permanent:
|
||||||
|
return "permanent"
|
||||||
|
case tbd:
|
||||||
|
return "tbd"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareIntf(a, b *net.Interface) int {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if a == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if b == nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return a.Index - b.Index
|
||||||
|
}
|
||||||
@@ -276,7 +276,7 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
|
|||||||
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
||||||
// Blocks until connection has been closed or connection timeout.
|
// Blocks until connection has been closed or connection timeout.
|
||||||
// ConnStatus will be set accordingly
|
// ConnStatus will be set accordingly
|
||||||
func (conn *Conn) Open() error {
|
func (conn *Conn) Open(ctx context.Context) error {
|
||||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
@@ -336,7 +336,7 @@ func (conn *Conn) Open() error {
|
|||||||
// at this point we received offer/answer and we are ready to gather candidates
|
// at this point we received offer/answer and we are ready to gather candidates
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
conn.status = StatusConnecting
|
conn.status = StatusConnecting
|
||||||
conn.ctx, conn.notifyDisconnected = context.WithCancel(context.Background())
|
conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx)
|
||||||
defer conn.notifyDisconnected()
|
defer conn.notifyDisconnected()
|
||||||
conn.mu.Unlock()
|
conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -448,9 +448,11 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if conn.wgProxy != nil {
|
if conn.wgProxy != nil {
|
||||||
_ = conn.wgProxy.CloseConn()
|
if err := conn.wgProxy.CloseConn(); err != nil {
|
||||||
|
log.Warnf("Failed to close turn connection: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, fmt.Errorf("update peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
@@ -730,7 +732,7 @@ func (conn *Conn) Close() error {
|
|||||||
// before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
|
// before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
|
||||||
// engine adds a new Conn for 4 and 5
|
// engine adds a new Conn for 4 and 5
|
||||||
// therefore peer 4 has 2 Conn objects
|
// therefore peer 4 has 2 Conn objects
|
||||||
log.Warnf("connection has been already closed or attempted closing not started coonection %s", conn.config.Key)
|
log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key)
|
||||||
return NewConnectionAlreadyClosed(conn.config.Key)
|
return NewConnectionAlreadyClosed(conn.config.Key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ type clientNetwork struct {
|
|||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
routes map[string]*route.Route
|
routes map[route.ID]*route.Route
|
||||||
routeUpdate chan routesUpdate
|
routeUpdate chan routesUpdate
|
||||||
peerStateUpdate chan struct{}
|
peerStateUpdate chan struct{}
|
||||||
routePeersNotifiers map[string]chan struct{}
|
routePeersNotifiers map[string]chan struct{}
|
||||||
@@ -50,7 +50,7 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
|
|||||||
stop: cancel,
|
stop: cancel,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
routes: make(map[string]*route.Route),
|
routes: make(map[route.ID]*route.Route),
|
||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
@@ -59,8 +59,8 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
|
|||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||||
routePeerStatuses := make(map[string]routerPeerStatus)
|
routePeerStatuses := make(map[route.ID]routerPeerStatus)
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -90,12 +90,12 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
// * Latency: Routes with lower latency are prioritized.
|
// * Latency: Routes with lower latency are prioritized.
|
||||||
//
|
//
|
||||||
// It returns the ID of the selected optimal route.
|
// It returns the ID of the selected optimal route.
|
||||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||||
chosen := ""
|
chosen := route.ID("")
|
||||||
chosenScore := float64(0)
|
chosenScore := float64(0)
|
||||||
currScore := float64(0)
|
currScore := float64(0)
|
||||||
|
|
||||||
currID := ""
|
currID := route.ID("")
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
currID = c.chosenRoute.ID
|
currID = c.chosenRoute.ID
|
||||||
}
|
}
|
||||||
@@ -153,15 +153,16 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
|
|
||||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||||
case chosen != currID:
|
case chosen != currID:
|
||||||
if currScore != 0 && currScore < chosenScore+0.1 {
|
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
||||||
|
if currScore != 0 && currScore+0.01 > chosenScore {
|
||||||
|
log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||||
return currID
|
return currID
|
||||||
} else {
|
|
||||||
var peer string
|
|
||||||
if route := c.routes[chosen]; route != nil {
|
|
||||||
peer = route.Peer
|
|
||||||
}
|
|
||||||
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, peer, chosenScore, c.network)
|
|
||||||
}
|
}
|
||||||
|
var p string
|
||||||
|
if rt := c.routes[chosen]; rt != nil {
|
||||||
|
p = rt.Peer
|
||||||
|
}
|
||||||
|
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network)
|
||||||
}
|
}
|
||||||
|
|
||||||
return chosen
|
return chosen
|
||||||
@@ -294,7 +295,7 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
||||||
updateMap := make(map[string]*route.Route)
|
updateMap := make(map[route.ID]*route.Route)
|
||||||
|
|
||||||
for _, r := range update.routes {
|
for _, r := range update.routes {
|
||||||
updateMap[r.ID] = r
|
updateMap[r.ID] = r
|
||||||
|
|||||||
@@ -12,21 +12,21 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
statuses map[string]routerPeerStatus
|
statuses map[route.ID]routerPeerStatus
|
||||||
expectedRouteID string
|
expectedRouteID route.ID
|
||||||
currentRoute string
|
currentRoute route.ID
|
||||||
existingRoutes map[string]*route.Route
|
existingRoutes map[route.ID]*route.Route
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "one route",
|
name: "one route",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -38,14 +38,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "one connected routes with relayed and direct",
|
name: "one connected routes with relayed and direct",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: true,
|
relayed: true,
|
||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -57,14 +57,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "one connected routes with relayed and no direct",
|
name: "one connected routes with relayed and no direct",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: true,
|
relayed: true,
|
||||||
direct: false,
|
direct: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -76,14 +76,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no connected peers",
|
name: "no connected peers",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: false,
|
connected: false,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
direct: false,
|
direct: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -95,7 +95,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with different metrics",
|
name: "multiple connected peers with different metrics",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@@ -107,7 +107,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: 9000,
|
Metric: 9000,
|
||||||
@@ -124,7 +124,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with one relayed",
|
name: "multiple connected peers with one relayed",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@@ -136,7 +136,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -153,7 +153,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with one direct",
|
name: "multiple connected peers with one direct",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@@ -165,7 +165,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
direct: false,
|
direct: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -182,7 +182,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with different latencies",
|
name: "multiple connected peers with different latencies",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
latency: 300 * time.Millisecond,
|
latency: 300 * time.Millisecond,
|
||||||
@@ -192,7 +192,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -209,7 +209,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should ignore routes with latency 0",
|
name: "should ignore routes with latency 0",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
latency: 0 * time.Millisecond,
|
latency: 0 * time.Millisecond,
|
||||||
@@ -219,7 +219,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -236,12 +236,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "current route with similar score and similar but slightly worse latency should not change",
|
name: "current route with similar score and similar but slightly worse latency should not change",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
direct: true,
|
direct: true,
|
||||||
latency: 12 * time.Millisecond,
|
latency: 15 * time.Millisecond,
|
||||||
},
|
},
|
||||||
"route2": {
|
"route2": {
|
||||||
connected: true,
|
connected: true,
|
||||||
@@ -250,7 +250,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@@ -265,9 +265,40 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
currentRoute: "route1",
|
currentRoute: "route1",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "current route with bad score should be changed to route with better score",
|
||||||
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 200 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "route1",
|
||||||
|
expectedRouteID: "route2",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "current chosen route doesn't exist anymore",
|
name: "current chosen route doesn't exist anymore",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@@ -281,7 +312,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
|||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||||
TriggerSelection(map[string][]*route.Route)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
@@ -43,7 +43,7 @@ type DefaultManager struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
clientNetworks map[string]*clientNetwork
|
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||||
routeSelector *routeselector.RouteSelector
|
routeSelector *routeselector.RouteSelector
|
||||||
serverRouter serverRouter
|
serverRouter serverRouter
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
@@ -57,7 +57,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
dm := &DefaultManager{
|
dm := &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||||
routeSelector: routeselector.NewRouteSelector(),
|
routeSelector: routeselector.NewRouteSelector(),
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
@@ -122,7 +122,7 @@ func (m *DefaultManager) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not updating routes as context is closed")
|
log.Infof("not updating routes as context is closed")
|
||||||
@@ -155,7 +155,7 @@ func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeL
|
|||||||
|
|
||||||
// InitialRouteRange return the list of initial routes. It used by mobile systems
|
// InitialRouteRange return the list of initial routes. It used by mobile systems
|
||||||
func (m *DefaultManager) InitialRouteRange() []string {
|
func (m *DefaultManager) InitialRouteRange() []string {
|
||||||
return m.notifier.initialRouteRanges()
|
return m.notifier.getInitialRouteRanges()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRouteSelector returns the route selector
|
// GetRouteSelector returns the route selector
|
||||||
@@ -164,16 +164,19 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the client routes
|
// GetClientRoutes returns the client routes
|
||||||
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
|
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
||||||
return m.clientNetworks
|
return m.clientNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||||
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
|
func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
networks = m.routeSelector.FilterSelected(networks)
|
networks = m.routeSelector.FilterSelected(networks)
|
||||||
|
|
||||||
|
m.notifier.onNewRoutes(networks)
|
||||||
|
|
||||||
m.stopObsoleteClients(networks)
|
m.stopObsoleteClients(networks)
|
||||||
|
|
||||||
for id, routes := range networks {
|
for id, routes := range networks {
|
||||||
@@ -190,7 +193,7 @@ func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||||
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
|
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||||
for id, client := range m.clientNetworks {
|
for id, client := range m.clientNetworks {
|
||||||
if _, ok := networks[id]; !ok {
|
if _, ok := networks[id]; !ok {
|
||||||
log.Debugf("Stopping client network watcher, %s", id)
|
log.Debugf("Stopping client network watcher, %s", id)
|
||||||
@@ -200,7 +203,7 @@ func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) {
|
||||||
// removing routes that do not exist as per the update from the Management service.
|
// removing routes that do not exist as per the update from the Management service.
|
||||||
m.stopObsoleteClients(networks)
|
m.stopObsoleteClients(networks)
|
||||||
|
|
||||||
@@ -219,15 +222,15 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
|
||||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
newClientRoutesIDMap := make(route.HAMap)
|
||||||
newServerRoutesMap := make(map[string]*route.Route)
|
newServerRoutesMap := make(map[route.ID]*route.Route)
|
||||||
ownNetworkIDs := make(map[string]bool)
|
ownNetworkIDs := make(map[route.HAUniqueID]bool)
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
haID := route.GetHAUniqueID(newRoute)
|
||||||
if newRoute.Peer == m.pubKey {
|
if newRoute.Peer == m.pubKey {
|
||||||
ownNetworkIDs[networkID] = true
|
ownNetworkIDs[haID] = true
|
||||||
// only linux is supported for now
|
// only linux is supported for now
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||||
@@ -238,12 +241,12 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
haID := route.GetHAUniqueID(newRoute)
|
||||||
if !ownNetworkIDs[networkID] {
|
if !ownNetworkIDs[haID] {
|
||||||
if !isPrefixSupported(newRoute.Network) {
|
if !isPrefixSupported(newRoute.Network) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,10 +264,7 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
|||||||
|
|
||||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||||
if !nbnet.CustomRoutingDisabled() {
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
switch runtime.GOOS {
|
return true
|
||||||
case "linux", "windows", "darwin", "ios":
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ import (
|
|||||||
|
|
||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||||
TriggerSelectionFunc func(map[string][]*route.Route)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
}
|
}
|
||||||
@@ -30,14 +30,14 @@ func (m *MockManager) InitialRouteRange() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
if m.UpdateRoutesFunc != nil {
|
if m.UpdateRoutesFunc != nil {
|
||||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||||
}
|
}
|
||||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
|
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
||||||
if m.TriggerSelectionFunc != nil {
|
if m.TriggerSelectionFunc != nil {
|
||||||
m.TriggerSelectionFunc(networks)
|
m.TriggerSelectionFunc(networks)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -10,8 +11,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type notifier struct {
|
type notifier struct {
|
||||||
initialRouteRangers []string
|
initialRouteRanges []string
|
||||||
routeRangers []string
|
routeRanges []string
|
||||||
|
|
||||||
listener listener.NetworkChangeListener
|
listener listener.NetworkChangeListener
|
||||||
listenerMux sync.Mutex
|
listenerMux sync.Mutex
|
||||||
@@ -33,10 +34,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
|
|||||||
nets = append(nets, r.Network.String())
|
nets = append(nets, r.Network.String())
|
||||||
}
|
}
|
||||||
sort.Strings(nets)
|
sort.Strings(nets)
|
||||||
n.initialRouteRangers = nets
|
n.initialRouteRanges = nets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
|
func (n *notifier) onNewRoutes(idMap route.HAMap) {
|
||||||
newNets := make([]string, 0)
|
newNets := make([]string, 0)
|
||||||
for _, routes := range idMap {
|
for _, routes := range idMap {
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
@@ -45,11 +46,18 @@ func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(newNets)
|
sort.Strings(newNets)
|
||||||
if !n.hasDiff(n.initialRouteRangers, newNets) {
|
switch runtime.GOOS {
|
||||||
return
|
case "android":
|
||||||
|
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if !n.hasDiff(n.routeRanges, newNets) {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n.routeRangers = newNets
|
n.routeRanges = newNets
|
||||||
|
|
||||||
n.notify()
|
n.notify()
|
||||||
}
|
}
|
||||||
@@ -62,7 +70,7 @@ func (n *notifier) notify() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func(l listener.NetworkChangeListener) {
|
go func(l listener.NetworkChangeListener) {
|
||||||
l.OnNetworkChanged(strings.Join(n.routeRangers, ","))
|
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
|
||||||
}(n.listener)
|
}(n.listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,6 +86,20 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *notifier) initialRouteRanges() []string {
|
func (n *notifier) getInitialRouteRanges() []string {
|
||||||
return n.initialRouteRangers
|
return addIPv6RangeIfNeeded(n.initialRouteRanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
|
||||||
|
func addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||||
|
ranges := inputRanges
|
||||||
|
for _, r := range inputRanges {
|
||||||
|
// we are intentionally adding the ipv6 default range in case of ipv4 default range
|
||||||
|
// to ensure that all traffic is managed by the tunnel interface on android
|
||||||
|
if r == "0.0.0.0/0" {
|
||||||
|
ranges = append(ranges, "::/0")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ranges
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package routemanager
|
|||||||
import "github.com/netbirdio/netbird/route"
|
import "github.com/netbirdio/netbird/route"
|
||||||
|
|
||||||
type serverRouter interface {
|
type serverRouter interface {
|
||||||
updateRoutes(map[string]*route.Route) error
|
updateRoutes(map[route.ID]*route.Route) error
|
||||||
removeFromServerNetwork(*route.Route) error
|
removeFromServerNetwork(*route.Route) error
|
||||||
cleanUp()
|
cleanUp()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
type defaultServerRouter struct {
|
type defaultServerRouter struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
routes map[string]*route.Route
|
routes map[route.ID]*route.Route
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
@@ -28,15 +28,15 @@ type defaultServerRouter struct {
|
|||||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
||||||
return &defaultServerRouter{
|
return &defaultServerRouter{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
routes: make(map[string]*route.Route),
|
routes: make(map[route.ID]*route.Route),
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||||
serverRoutesToRemove := make([]string, 0)
|
serverRoutesToRemove := make([]route.ID, 0)
|
||||||
|
|
||||||
for routeID := range m.routes {
|
for routeID := range m.routes {
|
||||||
update, found := routesMap[routeID]
|
update, found := routesMap[routeID]
|
||||||
@@ -168,7 +168,7 @@ func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair,
|
|||||||
return firewall.RouterPair{}, err
|
return firewall.RouterPair{}, err
|
||||||
}
|
}
|
||||||
return firewall.RouterPair{
|
return firewall.RouterPair{
|
||||||
ID: route.ID,
|
ID: string(route.ID),
|
||||||
Source: parsed.String(),
|
Source: parsed.String(),
|
||||||
Destination: route.Network.Masked().String(),
|
Destination: route.Network.Masked().String(),
|
||||||
Masquerade: route.Masquerade,
|
Masquerade: route.Masquerade,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|||||||
addr = netip.IPv6Unspecified()
|
addr = netip.IPv6Unspecified()
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultGateway, _, err := getNextHop(addr)
|
defaultGateway, _, err := GetNextHop(addr)
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
return fmt.Errorf("get existing route gateway: %s", err)
|
return fmt.Errorf("get existing route gateway: %s", err)
|
||||||
}
|
}
|
||||||
@@ -60,7 +60,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
gatewayHop, intf, err := getNextHop(defaultGateway)
|
gatewayHop, intf, err := GetNextHop(defaultGateway)
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||||
}
|
}
|
||||||
@@ -69,14 +69,14 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|||||||
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||||
r, err := netroute.New()
|
r, err := netroute.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||||
}
|
}
|
||||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to get route for %s: %v", ip, err)
|
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||||
return netip.Addr{}, nil, ErrRouteNotFound
|
return netip.Addr{}, nil, ErrRouteNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +163,7 @@ func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||||
nexthop, intf, err := getNextHop(addr)
|
nexthop, intf, err := GetNextHop(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
||||||
}
|
}
|
||||||
@@ -319,11 +319,11 @@ func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||||
}
|
}
|
||||||
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,39 +3,35 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
// selected BSD Route flags.
|
type Route struct {
|
||||||
const (
|
Dst netip.Prefix
|
||||||
RTF_UP = 0x1
|
Gw netip.Addr
|
||||||
RTF_GATEWAY = 0x2
|
Interface *net.Interface
|
||||||
RTF_HOST = 0x4
|
}
|
||||||
RTF_REJECT = 0x8
|
|
||||||
RTF_DYNAMIC = 0x10
|
|
||||||
RTF_MODIFIED = 0x20
|
|
||||||
RTF_STATIC = 0x800
|
|
||||||
RTF_BLACKHOLE = 0x1000
|
|
||||||
RTF_LOCAL = 0x200000
|
|
||||||
RTF_BROADCAST = 0x400000
|
|
||||||
RTF_MULTICAST = 0x800000
|
|
||||||
)
|
|
||||||
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
|
tab, err := retryFetchRIB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("fetch RIB: %v", err)
|
||||||
}
|
}
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
m := msg.(*route.RouteMessage)
|
m := msg.(*route.RouteMessage)
|
||||||
@@ -43,58 +39,121 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
if m.Version < 3 || m.Version > 5 {
|
if m.Version < 3 || m.Version > 5 {
|
||||||
return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version)
|
return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version)
|
||||||
}
|
}
|
||||||
if m.Type != 4 /* RTM_GET */ {
|
if m.Type != syscall.RTM_GET {
|
||||||
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Flags&RTF_UP == 0 ||
|
if m.Flags&syscall.RTF_UP == 0 ||
|
||||||
m.Flags&(RTF_REJECT|RTF_BLACKHOLE) != 0 {
|
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.Addrs) < 3 {
|
route, err := MsgToRoute(m)
|
||||||
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
|
if err != nil {
|
||||||
|
log.Warnf("Failed to parse route message: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if route.Dst.IsValid() {
|
||||||
addr, ok := toNetIPAddr(m.Addrs[0])
|
prefixList = append(prefixList, route.Dst)
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
cidr := 32
|
|
||||||
if mask := m.Addrs[2]; mask != nil {
|
|
||||||
cidr, ok = toCIDR(mask)
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
|
||||||
if routePrefix.IsValid() {
|
|
||||||
prefixList = append(prefixList, routePrefix)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
|
func retryFetchRIB() ([]byte, error) {
|
||||||
|
var out []byte
|
||||||
|
operation := func() error {
|
||||||
|
var err error
|
||||||
|
out, err = route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
|
||||||
|
if errors.Is(err, syscall.ENOMEM) {
|
||||||
|
log.Debug("~etrying fetchRIB due to 'cannot allocate memory' error")
|
||||||
|
return err
|
||||||
|
} else if err != nil {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expBackOff := backoff.NewExponentialBackOff()
|
||||||
|
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||||
|
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||||
|
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||||
|
|
||||||
|
err := backoff.Retry(operation, expBackOff)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch routing information: %w", err)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toNetIP(a route.Addr) netip.Addr {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
return netip.AddrFrom4(t.IP), true
|
return netip.AddrFrom4(t.IP)
|
||||||
|
case *route.Inet6Addr:
|
||||||
|
ip := netip.AddrFrom16(t.IP)
|
||||||
|
if t.ZoneID != 0 {
|
||||||
|
ip.WithZone(strconv.Itoa(t.ZoneID))
|
||||||
|
}
|
||||||
|
return ip
|
||||||
default:
|
default:
|
||||||
return netip.Addr{}, false
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toCIDR(a route.Addr) (int, bool) {
|
func ones(a route.Addr) (int, error) {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
mask, _ := net.IPMask(t.IP[:]).Size()
|
||||||
cidr, _ := mask.Size()
|
return mask, nil
|
||||||
return cidr, true
|
case *route.Inet6Addr:
|
||||||
|
mask, _ := net.IPMask(t.IP[:]).Size()
|
||||||
|
return mask, nil
|
||||||
default:
|
default:
|
||||||
return 0, false
|
return 0, fmt.Errorf("unexpected address type: %T", a)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
|
||||||
|
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]
|
||||||
|
|
||||||
|
addr := toNetIP(dstIP)
|
||||||
|
|
||||||
|
var nexthopAddr netip.Addr
|
||||||
|
var nexthopIntf *net.Interface
|
||||||
|
|
||||||
|
switch t := nexthop.(type) {
|
||||||
|
case *route.Inet4Addr, *route.Inet6Addr:
|
||||||
|
nexthopAddr = toNetIP(t)
|
||||||
|
case *route.LinkAddr:
|
||||||
|
nexthopIntf = &net.Interface{
|
||||||
|
Index: t.Index,
|
||||||
|
Name: t.Name,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected next hop type: %T", t)
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefix netip.Prefix
|
||||||
|
|
||||||
|
if dstMask == nil {
|
||||||
|
if addr.Is4() {
|
||||||
|
prefix = netip.PrefixFrom(addr, 32)
|
||||||
|
} else {
|
||||||
|
prefix = netip.PrefixFrom(addr, 128)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bits, err := ones(dstMask)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse mask: %v", dstMask)
|
||||||
|
}
|
||||||
|
prefix = netip.PrefixFrom(addr, bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Route{
|
||||||
|
Dst: prefix,
|
||||||
|
Gw: nexthopAddr,
|
||||||
|
Interface: nexthopIntf,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
57
client/internal/routemanager/systemops_bsd_test.go
Normal file
57
client/internal/routemanager/systemops_bsd_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBits(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr route.Addr
|
||||||
|
want int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 all ones",
|
||||||
|
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
|
||||||
|
want: 32,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 normal mask",
|
||||||
|
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
|
||||||
|
want: 24,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 all ones",
|
||||||
|
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
|
||||||
|
want: 128,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 normal mask",
|
||||||
|
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||||
|
want: 64,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unsupported type",
|
||||||
|
addr: &route.LinkAddr{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := ones(tt.addr)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -87,10 +87,10 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
err = removeVPNRoute(testCase.prefix, intf)
|
err = removeVPNRoute(testCase.prefix, intf)
|
||||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||||
|
|
||||||
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
require.NoError(t, err, "GetNextHop should not return err")
|
||||||
|
|
||||||
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if testCase.shouldBeRemoved {
|
if testCase.shouldBeRemoved {
|
||||||
@@ -104,7 +104,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetNextHop(t *testing.T) {
|
func TestGetNextHop(t *testing.T) {
|
||||||
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
}
|
}
|
||||||
@@ -130,7 +130,7 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localIP, _, err := getNextHop(testingPrefix.Addr())
|
localIP, _, err := GetNextHop(testingPrefix.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error: ", err)
|
t.Fatal("shouldn't return error: ", err)
|
||||||
}
|
}
|
||||||
@@ -146,7 +146,7 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||||
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
t.Log("defaultGateway: ", defaultGateway)
|
t.Log("defaultGateway: ", defaultGateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
@@ -410,8 +410,8 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prefixGateway, _, err := getNextHop(prefix.Addr())
|
prefixGateway, _, err := GetNextHop(prefix.Addr())
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
require.NoError(t, err, "GetNextHop should not return err")
|
||||||
if invert {
|
if invert {
|
||||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -16,13 +16,41 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/yusufpapurcu/wmi"
|
"github.com/yusufpapurcu/wmi"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Win32_IP4RouteTable struct {
|
type MSFT_NetRoute struct {
|
||||||
Destination string
|
DestinationPrefix string
|
||||||
Mask string
|
NextHop string
|
||||||
|
InterfaceIndex int32
|
||||||
|
InterfaceAlias string
|
||||||
|
AddressFamily uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type Route struct {
|
||||||
|
Destination netip.Prefix
|
||||||
|
Nexthop netip.Addr
|
||||||
|
Interface *net.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
type MSFT_NetNeighbor struct {
|
||||||
|
IPAddress string
|
||||||
|
LinkLayerAddress string
|
||||||
|
State uint8
|
||||||
|
AddressFamily uint16
|
||||||
|
InterfaceIndex uint32
|
||||||
|
InterfaceAlias string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Neighbor struct {
|
||||||
|
IPAddress netip.Addr
|
||||||
|
LinkLayerAddress string
|
||||||
|
State uint8
|
||||||
|
AddressFamily uint16
|
||||||
|
InterfaceIndex uint32
|
||||||
|
InterfaceAlias string
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
@@ -43,44 +71,92 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
mux.Lock()
|
mux.Lock()
|
||||||
defer mux.Unlock()
|
defer mux.Unlock()
|
||||||
|
|
||||||
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
|
||||||
|
|
||||||
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
|
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
|
||||||
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
|
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var routes []Win32_IP4RouteTable
|
routes, err := GetRoutes()
|
||||||
err := wmi.Query(query, &routes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get routes: %w", err)
|
return nil, fmt.Errorf("get routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
prefixList = nil
|
prefixList = nil
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
addr, err := netip.ParseAddr(route.Destination)
|
prefixList = append(prefixList, route.Destination)
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Unable to parse route destination %s: %v", route.Destination, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
maskSlice := net.ParseIP(route.Mask).To4()
|
|
||||||
if maskSlice == nil {
|
|
||||||
log.Warnf("Unable to parse route mask %s", route.Mask)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
|
|
||||||
cidr, _ := mask.Size()
|
|
||||||
|
|
||||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
|
||||||
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
|
|
||||||
prefixList = append(prefixList, routePrefix)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lastUpdate = time.Now()
|
lastUpdate = time.Now()
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRoutes() ([]Route, error) {
|
||||||
|
var entries []MSFT_NetRoute
|
||||||
|
|
||||||
|
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||||
|
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
|
||||||
|
return nil, fmt.Errorf("get routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var routes []Route
|
||||||
|
for _, entry := range entries {
|
||||||
|
dest, err := netip.ParsePrefix(entry.DestinationPrefix)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to parse route destination %s: %v", entry.DestinationPrefix, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nexthop, err := netip.ParseAddr(entry.NextHop)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to parse route next hop %s: %v", entry.NextHop, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var intf *net.Interface
|
||||||
|
if entry.InterfaceIndex != 0 {
|
||||||
|
intf = &net.Interface{
|
||||||
|
Index: int(entry.InterfaceIndex),
|
||||||
|
Name: entry.InterfaceAlias,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routes = append(routes, Route{
|
||||||
|
Destination: dest,
|
||||||
|
Nexthop: nexthop,
|
||||||
|
Interface: intf,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetNeighbors() ([]Neighbor, error) {
|
||||||
|
var entries []MSFT_NetNeighbor
|
||||||
|
query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor`
|
||||||
|
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var neighbors []Neighbor
|
||||||
|
for _, entry := range entries {
|
||||||
|
addr, err := netip.ParseAddr(entry.IPAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
neighbors = append(neighbors, Neighbor{
|
||||||
|
IPAddress: addr,
|
||||||
|
LinkLayerAddress: entry.LinkLayerAddress,
|
||||||
|
State: entry.State,
|
||||||
|
AddressFamily: entry.AddressFamily,
|
||||||
|
InterfaceIndex: entry.InterfaceIndex,
|
||||||
|
InterfaceAlias: entry.InterfaceAlias,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return neighbors, nil
|
||||||
|
}
|
||||||
|
|
||||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||||
args := []string{"add", prefix.String()}
|
args := []string{"add", prefix.String()}
|
||||||
|
|
||||||
@@ -98,7 +174,9 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
|
|||||||
args = append(args, "if", strconv.Itoa(intf.Index))
|
args = append(args, "if", strconv.Itoa(intf.Index))
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := exec.Command("route", args...).CombinedOutput()
|
routeCmd := uspfilter.GetSystem32Command("route")
|
||||||
|
|
||||||
|
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("route add: %w", err)
|
return fmt.Errorf("route add: %w", err)
|
||||||
@@ -127,7 +205,9 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interf
|
|||||||
args = append(args, nexthop.Unmap().String())
|
args = append(args, nexthop.Unmap().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := exec.Command("route", args...).CombinedOutput()
|
routeCmd := uspfilter.GetSystem32Command("route")
|
||||||
|
|
||||||
|
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -12,22 +12,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RouteSelector struct {
|
type RouteSelector struct {
|
||||||
selectedRoutes map[string]struct{}
|
selectedRoutes map[route.NetID]struct{}
|
||||||
selectAll bool
|
selectAll bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouteSelector() *RouteSelector {
|
func NewRouteSelector() *RouteSelector {
|
||||||
return &RouteSelector{
|
return &RouteSelector{
|
||||||
selectedRoutes: map[string]struct{}{},
|
selectedRoutes: map[route.NetID]struct{}{},
|
||||||
// default selects all routes
|
// default selects all routes
|
||||||
selectAll: true,
|
selectAll: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectRoutes updates the selected routes based on the provided route IDs.
|
// SelectRoutes updates the selected routes based on the provided route IDs.
|
||||||
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
|
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
|
||||||
if !appendRoute {
|
if !appendRoute {
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var multiErr *multierror.Error
|
var multiErr *multierror.Error
|
||||||
@@ -51,15 +51,15 @@ func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRout
|
|||||||
// SelectAllRoutes sets the selector to select all routes.
|
// SelectAllRoutes sets the selector to select all routes.
|
||||||
func (rs *RouteSelector) SelectAllRoutes() {
|
func (rs *RouteSelector) SelectAllRoutes() {
|
||||||
rs.selectAll = true
|
rs.selectAll = true
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeselectRoutes removes specific routes from the selection.
|
// DeselectRoutes removes specific routes from the selection.
|
||||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||||
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
|
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
for _, route := range allRoutes {
|
for _, route := range allRoutes {
|
||||||
rs.selectedRoutes[route] = struct{}{}
|
rs.selectedRoutes[route] = struct{}{}
|
||||||
}
|
}
|
||||||
@@ -85,11 +85,11 @@ func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) err
|
|||||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||||
func (rs *RouteSelector) DeselectAllRoutes() {
|
func (rs *RouteSelector) DeselectAllRoutes() {
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSelected checks if a specific route is selected.
|
// IsSelected checks if a specific route is selected.
|
||||||
func (rs *RouteSelector) IsSelected(routeID string) bool {
|
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -98,18 +98,14 @@ func (rs *RouteSelector) IsSelected(routeID string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FilterSelected removes unselected routes from the provided map.
|
// FilterSelected removes unselected routes from the provided map.
|
||||||
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
|
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
return maps.Clone(routes)
|
return maps.Clone(routes)
|
||||||
}
|
}
|
||||||
|
|
||||||
filtered := map[string][]*route.Route{}
|
filtered := route.HAMap{}
|
||||||
for id, rt := range routes {
|
for id, rt := range routes {
|
||||||
netID := id
|
if rs.IsSelected(id.NetID()) {
|
||||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
|
||||||
netID = id[:i]
|
|
||||||
}
|
|
||||||
if rs.IsSelected(netID) {
|
|
||||||
filtered[id] = rt
|
filtered[id] = rt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,53 +12,53 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRouteSelector_SelectRoutes(t *testing.T) {
|
func TestRouteSelector_SelectRoutes(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
selectRoutes []string
|
selectRoutes []route.NetID
|
||||||
append bool
|
append bool
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
wantError bool
|
wantError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Select specific routes, initial all selected",
|
name: "Select specific routes, initial all selected",
|
||||||
selectRoutes: []string{"route1", "route2"},
|
selectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{"route1", "route2"},
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Select specific routes, initial all deselected",
|
name: "Select specific routes, initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
selectRoutes: []string{"route1", "route2"},
|
selectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{"route1", "route2"},
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Select specific routes with initial selection",
|
name: "Select specific routes with initial selection",
|
||||||
initialSelected: []string{"route1"},
|
initialSelected: []route.NetID{"route1"},
|
||||||
selectRoutes: []string{"route2", "route3"},
|
selectRoutes: []route.NetID{"route2", "route3"},
|
||||||
wantSelected: []string{"route2", "route3"},
|
wantSelected: []route.NetID{"route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Select non-existing route",
|
name: "Select non-existing route",
|
||||||
selectRoutes: []string{"route1", "route4"},
|
selectRoutes: []route.NetID{"route1", "route4"},
|
||||||
wantSelected: []string{"route1"},
|
wantSelected: []route.NetID{"route1"},
|
||||||
wantError: true,
|
wantError: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Append route with initial selection",
|
name: "Append route with initial selection",
|
||||||
initialSelected: []string{"route1"},
|
initialSelected: []route.NetID{"route1"},
|
||||||
selectRoutes: []string{"route2"},
|
selectRoutes: []route.NetID{"route2"},
|
||||||
append: true,
|
append: true,
|
||||||
wantSelected: []string{"route1", "route2"},
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Append route without initial selection",
|
name: "Append route without initial selection",
|
||||||
selectRoutes: []string{"route2"},
|
selectRoutes: []route.NetID{"route2"},
|
||||||
append: true,
|
append: true,
|
||||||
wantSelected: []string{"route2"},
|
wantSelected: []route.NetID{"route2"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,32 +86,32 @@ func TestRouteSelector_SelectRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all deselected",
|
name: "Initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial some selected",
|
name: "Initial some selected",
|
||||||
initialSelected: []string{"route1"},
|
initialSelected: []route.NetID{"route1"},
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
initialSelected: []string{"route1", "route2", "route3"},
|
initialSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,39 +134,39 @@ func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
deselectRoutes []string
|
deselectRoutes []route.NetID
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
wantError bool
|
wantError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Deselect specific routes, initial all selected",
|
name: "Deselect specific routes, initial all selected",
|
||||||
deselectRoutes: []string{"route1", "route2"},
|
deselectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{"route3"},
|
wantSelected: []route.NetID{"route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Deselect specific routes, initial all deselected",
|
name: "Deselect specific routes, initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
deselectRoutes: []string{"route1", "route2"},
|
deselectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Deselect specific routes with initial selection",
|
name: "Deselect specific routes with initial selection",
|
||||||
initialSelected: []string{"route1", "route2"},
|
initialSelected: []route.NetID{"route1", "route2"},
|
||||||
deselectRoutes: []string{"route1", "route3"},
|
deselectRoutes: []route.NetID{"route1", "route3"},
|
||||||
wantSelected: []string{"route2"},
|
wantSelected: []route.NetID{"route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Deselect non-existing route",
|
name: "Deselect non-existing route",
|
||||||
initialSelected: []string{"route1", "route2"},
|
initialSelected: []route.NetID{"route1", "route2"},
|
||||||
deselectRoutes: []string{"route1", "route4"},
|
deselectRoutes: []route.NetID{"route1", "route4"},
|
||||||
wantSelected: []string{"route2"},
|
wantSelected: []route.NetID{"route2"},
|
||||||
wantError: true,
|
wantError: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -195,32 +195,32 @@ func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteSelector_DeselectAll(t *testing.T) {
|
func TestRouteSelector_DeselectAll(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all deselected",
|
name: "Initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial some selected",
|
name: "Initial some selected",
|
||||||
initialSelected: []string{"route1", "route2"},
|
initialSelected: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
initialSelected: []string{"route1", "route2", "route3"},
|
initialSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,7 +245,7 @@ func TestRouteSelector_DeselectAll(t *testing.T) {
|
|||||||
func TestRouteSelector_IsSelected(t *testing.T) {
|
func TestRouteSelector_IsSelected(t *testing.T) {
|
||||||
rs := routeselector.NewRouteSelector()
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, rs.IsSelected("route1"))
|
assert.True(t, rs.IsSelected("route1"))
|
||||||
@@ -257,10 +257,10 @@ func TestRouteSelector_IsSelected(t *testing.T) {
|
|||||||
func TestRouteSelector_FilterSelected(t *testing.T) {
|
func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||||
rs := routeselector.NewRouteSelector()
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
routes := map[string][]*route.Route{
|
routes := route.HAMap{
|
||||||
"route1-10.0.0.0/8": {},
|
"route1-10.0.0.0/8": {},
|
||||||
"route2-192.168.0.0/16": {},
|
"route2-192.168.0.0/16": {},
|
||||||
"route3-172.16.0.0/12": {},
|
"route3-172.16.0.0/12": {},
|
||||||
@@ -268,7 +268,7 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
|
|||||||
|
|
||||||
filtered := rs.FilterSelected(routes)
|
filtered := rs.FilterSelected(routes)
|
||||||
|
|
||||||
assert.Equal(t, map[string][]*route.Route{
|
assert.Equal(t, route.HAMap{
|
||||||
"route1-10.0.0.0/8": {},
|
"route1-10.0.0.0/8": {},
|
||||||
"route2-192.168.0.0/16": {},
|
"route2-192.168.0.0/16": {},
|
||||||
}, filtered)
|
}, filtered)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package stdnet
|
package stdnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -19,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range disallowList {
|
for _, s := range disallowList {
|
||||||
if strings.HasPrefix(iFace, s) {
|
if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" {
|
||||||
log.Tracef("ignoring interface %s - it is not allowed", iFace)
|
log.Tracef("ignoring interface %s - it is not allowed", iFace)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ func NewFactory(wgPort int) *Factory {
|
|||||||
f := &Factory{wgPort: wgPort}
|
f := &Factory{wgPort: wgPort}
|
||||||
|
|
||||||
ebpfProxy := NewWGEBPFProxy(wgPort)
|
ebpfProxy := NewWGEBPFProxy(wgPort)
|
||||||
err := ebpfProxy.Listen()
|
err := ebpfProxy.listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||||
return f
|
return f
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
AddTurnConn(urnConn net.Conn) (net.Addr, error)
|
AddTurnConn(turnConn net.Conn) (net.Addr, error)
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
Free() error
|
Free() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ import (
|
|||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
|
localWGListenPort int
|
||||||
ebpfManager ebpfMgr.Manager
|
ebpfManager ebpfMgr.Manager
|
||||||
lastUsedPort uint16
|
lastUsedPort uint16
|
||||||
localWGListenPort int
|
|
||||||
|
|
||||||
turnConnStore map[uint16]net.Conn
|
turnConnStore map[uint16]net.Conn
|
||||||
turnConnMutex sync.Mutex
|
turnConnMutex sync.Mutex
|
||||||
@@ -45,8 +45,8 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
|||||||
return wgProxy
|
return wgProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen load ebpf program and listen the proxy
|
// listen load ebpf program and listen the proxy
|
||||||
func (p *WGEBPFProxy) Listen() error {
|
func (p *WGEBPFProxy) listen() error {
|
||||||
pl := portLookup{}
|
pl := portLookup{}
|
||||||
wgPorxyPort, err := pl.searchFreePort()
|
wgPorxyPort, err := pl.searchFreePort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -72,7 +72,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cErr := p.Free()
|
cErr := p.Free()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
log.Errorf("failed to close the wgproxy: %s", cErr)
|
log.Errorf("Failed to close the wgproxy: %s", cErr)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -130,6 +130,11 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
log.Tracef("stop proxying turn traffic to wg: %d", endpointPort)
|
||||||
|
p.removeTurnConn(endpointPort)
|
||||||
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
n, err := remoteConn.Read(buf)
|
n, err := remoteConn.Read(buf)
|
||||||
@@ -137,12 +142,13 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
|||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||||
}
|
}
|
||||||
p.removeTurnConn(endpointPort)
|
|
||||||
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = p.sendPkg(buf[:n], endpointPort)
|
err = p.sendPkg(buf[:n], endpointPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -186,11 +192,9 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
||||||
log.Tracef("remove turn conn from store by port: %d", turnConnID)
|
|
||||||
p.turnConnMutex.Lock()
|
p.turnConnMutex.Lock()
|
||||||
defer p.turnConnMutex.Unlock()
|
defer p.turnConnMutex.Unlock()
|
||||||
delete(p.turnConnStore, turnConnID)
|
delete(p.turnConnStore, turnConnID)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
||||||
@@ -266,6 +270,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
|||||||
|
|
||||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("set network layer for checksum: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,8 +278,12 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
|||||||
|
|
||||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("serialize layers: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost})
|
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
||||||
return err
|
log.Errorf("write to raw conn: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ type WGUserSpaceProxy struct {
|
|||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
||||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||||
log.Debugf("instantiate new userspace proxy")
|
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||||
p := &WGUserSpaceProxy{
|
p := &WGUserSpaceProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
}
|
}
|
||||||
@@ -31,8 +31,8 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
// AddTurnConn start the proxy with the given remote conn
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = turnConn
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
|
|||||||
@@ -2,10 +2,15 @@ package NetBirdSDK
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
@@ -14,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -38,6 +44,12 @@ type CustomLogger interface {
|
|||||||
Error(message string)
|
Error(message string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type selectRoute struct {
|
||||||
|
NetID string
|
||||||
|
Network netip.Prefix
|
||||||
|
Selected bool
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
@@ -55,6 +67,7 @@ type Client struct {
|
|||||||
onHostDnsFn func([]string)
|
onHostDnsFn func([]string)
|
||||||
dnsManager dns.IosDnsManager
|
dnsManager dns.IosDnsManager
|
||||||
loginComplete bool
|
loginComplete bool
|
||||||
|
connectClient *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -107,7 +120,9 @@ func (c *Client) Run(fd int32, interfaceName string) error {
|
|||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.onHostDnsFn = func([]string) {}
|
c.onHostDnsFn = func([]string) {}
|
||||||
cfg.WgIface = interfaceName
|
cfg.WgIface = interfaceName
|
||||||
return internal.RunClientiOS(ctx, cfg, c.recorder, fd, c.networkChangeListener, c.dnsManager)
|
|
||||||
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
|
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -133,10 +148,29 @@ func (c *Client) GetStatusDetails() *StatusDetails {
|
|||||||
|
|
||||||
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
||||||
for n, p := range fullStatus.Peers {
|
for n, p := range fullStatus.Peers {
|
||||||
|
var routes = RoutesDetails{}
|
||||||
|
for r := range p.GetRoutes() {
|
||||||
|
routeInfo := RoutesInfo{r}
|
||||||
|
routes.items = append(routes.items, routeInfo)
|
||||||
|
}
|
||||||
pi := PeerInfo{
|
pi := PeerInfo{
|
||||||
p.IP,
|
IP: p.IP,
|
||||||
p.FQDN,
|
FQDN: p.FQDN,
|
||||||
p.ConnStatus.String(),
|
LocalIceCandidateEndpoint: p.LocalIceCandidateEndpoint,
|
||||||
|
RemoteIceCandidateEndpoint: p.RemoteIceCandidateEndpoint,
|
||||||
|
LocalIceCandidateType: p.LocalIceCandidateType,
|
||||||
|
RemoteIceCandidateType: p.RemoteIceCandidateType,
|
||||||
|
PubKey: p.PubKey,
|
||||||
|
Latency: formatDuration(p.Latency),
|
||||||
|
BytesRx: p.BytesRx,
|
||||||
|
BytesTx: p.BytesTx,
|
||||||
|
ConnStatus: p.ConnStatus.String(),
|
||||||
|
ConnStatusUpdate: p.ConnStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||||
|
Direct: p.Direct,
|
||||||
|
LastWireguardHandshake: p.LastWireguardHandshake.String(),
|
||||||
|
Relayed: p.Relayed,
|
||||||
|
RosenpassEnabled: p.RosenpassEnabled,
|
||||||
|
Routes: routes,
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
}
|
}
|
||||||
@@ -223,3 +257,142 @@ func (c *Client) IsLoginComplete() bool {
|
|||||||
func (c *Client) ClearLoginComplete() {
|
func (c *Client) ClearLoginComplete() {
|
||||||
c.loginComplete = false
|
c.loginComplete = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := c.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
routesMap := engine.GetClientRoutesWithNetID()
|
||||||
|
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
||||||
|
|
||||||
|
var routes []*selectRoute
|
||||||
|
for id, rt := range routesMap {
|
||||||
|
if len(rt) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
route := &selectRoute{
|
||||||
|
NetID: string(id),
|
||||||
|
Network: rt[0].Network,
|
||||||
|
Selected: routeSelector.IsSelected(id),
|
||||||
|
}
|
||||||
|
routes = append(routes, route)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(routes, func(i, j int) bool {
|
||||||
|
iPrefix := routes[i].Network.Bits()
|
||||||
|
jPrefix := routes[j].Network.Bits()
|
||||||
|
|
||||||
|
if iPrefix == jPrefix {
|
||||||
|
iAddr := routes[i].Network.Addr()
|
||||||
|
jAddr := routes[j].Network.Addr()
|
||||||
|
if iAddr == jAddr {
|
||||||
|
return routes[i].NetID < routes[j].NetID
|
||||||
|
}
|
||||||
|
return iAddr.String() < jAddr.String()
|
||||||
|
}
|
||||||
|
return iPrefix < jPrefix
|
||||||
|
})
|
||||||
|
|
||||||
|
var routeSelection []RoutesSelectionInfo
|
||||||
|
for _, r := range routes {
|
||||||
|
routeSelection = append(routeSelection, RoutesSelectionInfo{
|
||||||
|
ID: r.NetID,
|
||||||
|
Network: r.Network.String(),
|
||||||
|
Selected: r.Selected,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
|
||||||
|
return &routeSelectionDetails, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SelectRoute(id string) error {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
return fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := c.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
routeManager := engine.GetRouteManager()
|
||||||
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
|
if id == "All" {
|
||||||
|
log.Debugf("select all routes")
|
||||||
|
routeSelector.SelectAllRoutes()
|
||||||
|
} else {
|
||||||
|
log.Debugf("select route with id: %s", id)
|
||||||
|
routes := toNetIDs([]string{id})
|
||||||
|
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||||
|
log.Debugf("error when selecting routes: %s", err)
|
||||||
|
return fmt.Errorf("select routes: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) DeselectRoute(id string) error {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
return fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
engine := c.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
routeManager := engine.GetRouteManager()
|
||||||
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
|
if id == "All" {
|
||||||
|
log.Debugf("deselect all routes")
|
||||||
|
routeSelector.DeselectAllRoutes()
|
||||||
|
} else {
|
||||||
|
log.Debugf("deselect route with id: %s", id)
|
||||||
|
routes := toNetIDs([]string{id})
|
||||||
|
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||||
|
log.Debugf("error when deselecting routes: %s", err)
|
||||||
|
return fmt.Errorf("deselect routes: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatDuration(d time.Duration) string {
|
||||||
|
ds := d.String()
|
||||||
|
dotIndex := strings.Index(ds, ".")
|
||||||
|
if dotIndex != -1 {
|
||||||
|
// Determine end of numeric part, ensuring we stop at two decimal places or the actual end if fewer
|
||||||
|
endIndex := dotIndex + 3
|
||||||
|
if endIndex > len(ds) {
|
||||||
|
endIndex = len(ds)
|
||||||
|
}
|
||||||
|
// Find where the numeric part ends by finding the first non-digit character after the dot
|
||||||
|
unitStart := endIndex
|
||||||
|
for unitStart < len(ds) && (ds[unitStart] >= '0' && ds[unitStart] <= '9') {
|
||||||
|
unitStart++
|
||||||
|
}
|
||||||
|
// Ensures that we only take the unit characters after the numerical part
|
||||||
|
if unitStart < len(ds) {
|
||||||
|
return ds[:endIndex] + ds[unitStart:]
|
||||||
|
}
|
||||||
|
return ds[:endIndex] // In case no units are found after the digits
|
||||||
|
}
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
func toNetIDs(routes []string) []route.NetID {
|
||||||
|
var netIDs []route.NetID
|
||||||
|
for _, rt := range routes {
|
||||||
|
netIDs = append(netIDs, route.NetID(rt))
|
||||||
|
}
|
||||||
|
return netIDs
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,9 +2,28 @@ package NetBirdSDK
|
|||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
type PeerInfo struct {
|
type PeerInfo struct {
|
||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
LocalIceCandidateEndpoint string
|
||||||
|
RemoteIceCandidateEndpoint string
|
||||||
|
LocalIceCandidateType string
|
||||||
|
RemoteIceCandidateType string
|
||||||
|
PubKey string
|
||||||
|
Latency string
|
||||||
|
BytesRx int64
|
||||||
|
BytesTx int64
|
||||||
|
ConnStatus string
|
||||||
|
ConnStatusUpdate string
|
||||||
|
Direct bool
|
||||||
|
LastWireguardHandshake string
|
||||||
|
Relayed bool
|
||||||
|
RosenpassEnabled bool
|
||||||
|
Routes RoutesDetails
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutes return with RouteDetails
|
||||||
|
func (p PeerInfo) GetRouteDetails() *RoutesDetails {
|
||||||
|
return &p.Routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||||
@@ -16,6 +35,21 @@ type PeerInfoCollection interface {
|
|||||||
GetIP() string
|
GetIP() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoutesInfoCollection made for Java layer to get non default types as collection
|
||||||
|
type RoutesInfoCollection interface {
|
||||||
|
Add(s string) RoutesInfoCollection
|
||||||
|
Get(i int) string
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
type RoutesDetails struct {
|
||||||
|
items []RoutesInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type RoutesInfo struct {
|
||||||
|
Route string
|
||||||
|
}
|
||||||
|
|
||||||
// StatusDetails is the implementation of the PeerInfoCollection
|
// StatusDetails is the implementation of the PeerInfoCollection
|
||||||
type StatusDetails struct {
|
type StatusDetails struct {
|
||||||
items []PeerInfo
|
items []PeerInfo
|
||||||
@@ -23,6 +57,22 @@ type StatusDetails struct {
|
|||||||
ip string
|
ip string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add new PeerInfo to the collection
|
||||||
|
func (array RoutesDetails) Add(s RoutesInfo) RoutesDetails {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return an element of the collection
|
||||||
|
func (array RoutesDetails) Get(i int) *RoutesInfo {
|
||||||
|
return &array.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size return with the size of the collection
|
||||||
|
func (array RoutesDetails) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
|
|
||||||
// Add new PeerInfo to the collection
|
// Add new PeerInfo to the collection
|
||||||
func (array StatusDetails) Add(s PeerInfo) StatusDetails {
|
func (array StatusDetails) Add(s PeerInfo) StatusDetails {
|
||||||
array.items = append(array.items, s)
|
array.items = append(array.items, s)
|
||||||
|
|||||||
36
client/ios/NetBirdSDK/routes.go
Normal file
36
client/ios/NetBirdSDK/routes.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package NetBirdSDK
|
||||||
|
|
||||||
|
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
||||||
|
type RoutesSelectionInfoCollection interface {
|
||||||
|
Add(s string) RoutesSelectionInfoCollection
|
||||||
|
Get(i int) string
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
type RoutesSelectionDetails struct {
|
||||||
|
All bool
|
||||||
|
Append bool
|
||||||
|
items []RoutesSelectionInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type RoutesSelectionInfo struct {
|
||||||
|
ID string
|
||||||
|
Network string
|
||||||
|
Selected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new PeerInfo to the collection
|
||||||
|
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return an element of the collection
|
||||||
|
func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
|
||||||
|
return &array.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size return with the size of the collection
|
||||||
|
func (array RoutesSelectionDetails) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
@@ -120,6 +120,7 @@ type LoginRequest struct {
|
|||||||
ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"`
|
ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"`
|
||||||
RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"`
|
RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"`
|
||||||
ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"`
|
ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"`
|
||||||
|
NetworkMonitor *bool `protobuf:"varint,18,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *LoginRequest) Reset() {
|
func (x *LoginRequest) Reset() {
|
||||||
@@ -274,6 +275,13 @@ func (x *LoginRequest) GetExtraIFaceBlacklist() []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *LoginRequest) GetNetworkMonitor() bool {
|
||||||
|
if x != nil && x.NetworkMonitor != nil {
|
||||||
|
return *x.NetworkMonitor
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
state protoimpl.MessageState
|
state protoimpl.MessageState
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
@@ -1893,7 +1901,7 @@ var file_daemon_proto_rawDesc = []byte{
|
|||||||
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
|
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
|
||||||
0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
|
0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
|
||||||
0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74,
|
0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74,
|
||||||
0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x8f, 0x07, 0x0a, 0x0c, 0x4c, 0x6f,
|
0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xcf, 0x07, 0x0a, 0x0c, 0x4c, 0x6f,
|
||||||
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65,
|
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65,
|
||||||
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65,
|
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65,
|
||||||
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61,
|
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61,
|
||||||
@@ -1941,16 +1949,20 @@ var file_daemon_proto_rawDesc = []byte{
|
|||||||
0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63,
|
0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63,
|
||||||
0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09,
|
0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09,
|
||||||
0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63,
|
0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63,
|
||||||
0x6b, 0x6c, 0x69, 0x73, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
|
0x6b, 0x6c, 0x69, 0x73, 0x74, 0x12, 0x2b, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
|
||||||
0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69,
|
0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x48, 0x07, 0x52,
|
||||||
0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e,
|
0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x88,
|
||||||
0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17,
|
0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||||
0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68,
|
0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, 0x65,
|
||||||
0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61,
|
0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x69,
|
||||||
0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13,
|
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, 0x5f,
|
||||||
0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f,
|
0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65,
|
||||||
0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
|
0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65,
|
||||||
0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d,
|
0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f,
|
||||||
|
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64,
|
||||||
|
0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65,
|
||||||
|
0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, 0x5f, 0x6e, 0x65, 0x74,
|
||||||
|
0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x22, 0xb5, 0x01, 0x0a, 0x0d,
|
||||||
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a,
|
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a,
|
||||||
0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01,
|
0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01,
|
||||||
0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
|
0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
|
||||||
|
|||||||
@@ -87,6 +87,8 @@ message LoginRequest {
|
|||||||
optional bool rosenpassPermissive = 16;
|
optional bool rosenpassPermissive = 16;
|
||||||
|
|
||||||
repeated string extraIFaceBlacklist = 17;
|
repeated string extraIFaceBlacklist = 17;
|
||||||
|
|
||||||
|
optional bool networkMonitor = 18;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
|
|||||||
@@ -9,10 +9,11 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type selectRoute struct {
|
type selectRoute struct {
|
||||||
NetID string
|
NetID route.NetID
|
||||||
Network netip.Prefix
|
Network netip.Prefix
|
||||||
Selected bool
|
Selected bool
|
||||||
}
|
}
|
||||||
@@ -22,12 +23,17 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
if s.engine == nil {
|
if s.connectClient == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
routesMap := s.engine.GetClientRoutesWithNetID()
|
engine := s.connectClient.Engine()
|
||||||
routeSelector := s.engine.GetRouteManager().GetRouteSelector()
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
routesMap := engine.GetClientRoutesWithNetID()
|
||||||
|
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
||||||
|
|
||||||
var routes []*selectRoute
|
var routes []*selectRoute
|
||||||
for id, rt := range routesMap {
|
for id, rt := range routesMap {
|
||||||
@@ -60,7 +66,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
|
|||||||
var pbRoutes []*proto.Route
|
var pbRoutes []*proto.Route
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
pbRoutes = append(pbRoutes, &proto.Route{
|
pbRoutes = append(pbRoutes, &proto.Route{
|
||||||
ID: route.NetID,
|
ID: string(route.NetID),
|
||||||
Network: route.Network.String(),
|
Network: route.Network.String(),
|
||||||
Selected: route.Selected,
|
Selected: route.Selected,
|
||||||
})
|
})
|
||||||
@@ -76,16 +82,26 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest)
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
routeManager := s.engine.GetRouteManager()
|
if s.connectClient == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
routeManager := engine.GetRouteManager()
|
||||||
routeSelector := routeManager.GetRouteSelector()
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
if req.GetAll() {
|
if req.GetAll() {
|
||||||
routeSelector.SelectAllRoutes()
|
routeSelector.SelectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
routes := toNetIDs(req.GetRouteIDs())
|
||||||
|
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||||
return nil, fmt.Errorf("select routes: %w", err)
|
return nil, fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(s.engine.GetClientRoutes())
|
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||||
|
|
||||||
return &proto.SelectRoutesResponse{}, nil
|
return &proto.SelectRoutesResponse{}, nil
|
||||||
}
|
}
|
||||||
@@ -95,16 +111,34 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
routeManager := s.engine.GetRouteManager()
|
if s.connectClient == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
routeManager := engine.GetRouteManager()
|
||||||
routeSelector := routeManager.GetRouteSelector()
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
if req.GetAll() {
|
if req.GetAll() {
|
||||||
routeSelector.DeselectAllRoutes()
|
routeSelector.DeselectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
routes := toNetIDs(req.GetRouteIDs())
|
||||||
|
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(s.engine.GetClientRoutes())
|
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||||
|
|
||||||
return &proto.SelectRoutesResponse{}, nil
|
return &proto.SelectRoutesResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toNetIDs(routes []string) []route.NetID {
|
||||||
|
var netIDs []route.NetID
|
||||||
|
for _, rt := range routes {
|
||||||
|
netIDs = append(netIDs, route.NetID(rt))
|
||||||
|
}
|
||||||
|
return netIDs
|
||||||
|
}
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ type Server struct {
|
|||||||
config *internal.Config
|
config *internal.Config
|
||||||
proto.UnimplementedDaemonServiceServer
|
proto.UnimplementedDaemonServiceServer
|
||||||
|
|
||||||
engine *internal.Engine
|
connectClient *internal.ConnectClient
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
sessionWatcher *internal.SessionWatcher
|
sessionWatcher *internal.SessionWatcher
|
||||||
@@ -143,11 +143,8 @@ func (s *Server) Start() error {
|
|||||||
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
|
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
|
||||||
}
|
}
|
||||||
|
|
||||||
engineChan := make(chan *internal.Engine, 1)
|
|
||||||
go s.watchEngine(ctx, engineChan)
|
|
||||||
|
|
||||||
if !config.DisableAutoConnect {
|
if !config.DisableAutoConnect {
|
||||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
|
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -158,7 +155,6 @@ func (s *Server) Start() error {
|
|||||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||||
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
|
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
|
||||||
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
|
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
|
||||||
engineChan chan<- *internal.Engine,
|
|
||||||
) {
|
) {
|
||||||
backOff := getConnectWithBackoff(ctx)
|
backOff := getConnectWithBackoff(ctx)
|
||||||
retryStarted := false
|
retryStarted := false
|
||||||
@@ -188,7 +184,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
|
|||||||
|
|
||||||
runOperation := func() error {
|
runOperation := func() error {
|
||||||
log.Tracef("running client connection")
|
log.Tracef("running client connection")
|
||||||
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
|
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
||||||
|
err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||||
}
|
}
|
||||||
@@ -358,6 +355,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
s.latestConfigInput.WireguardPort = &port
|
s.latestConfigInput.WireguardPort = &port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if msg.NetworkMonitor != nil {
|
||||||
|
inputConfig.NetworkMonitor = msg.NetworkMonitor
|
||||||
|
s.latestConfigInput.NetworkMonitor = msg.NetworkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
if len(msg.ExtraIFaceBlacklist) > 0 {
|
if len(msg.ExtraIFaceBlacklist) > 0 {
|
||||||
inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||||
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||||
@@ -568,10 +570,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||||
|
|
||||||
engineChan := make(chan *internal.Engine, 1)
|
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
|
||||||
go s.watchEngine(ctx, engineChan)
|
|
||||||
|
|
||||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
|
|
||||||
|
|
||||||
return &proto.UpResponse{}, nil
|
return &proto.UpResponse{}, nil
|
||||||
}
|
}
|
||||||
@@ -588,8 +587,6 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
|
|||||||
state := internal.CtxGetState(s.rootCtx)
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
state.Set(internal.StatusIdle)
|
state.Set(internal.StatusIdle)
|
||||||
|
|
||||||
s.engine = nil
|
|
||||||
|
|
||||||
return &proto.DownResponse{}, nil
|
return &proto.DownResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -683,22 +680,6 @@ func (s *Server) onSessionExpire() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// watchEngine watches the engine channel and updates the engine state
|
|
||||||
func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) {
|
|
||||||
log.Tracef("Started watching engine")
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
s.engine = nil
|
|
||||||
log.Tracef("Stopped watching engine")
|
|
||||||
return
|
|
||||||
case engine := <-engineChan:
|
|
||||||
log.Tracef("Received engine from watcher")
|
|
||||||
s.engine = engine
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||||
pbFullStatus := proto.FullStatus{
|
pbFullStatus := proto.FullStatus{
|
||||||
ManagementState: &proto.ManagementState{},
|
ManagementState: &proto.ManagementState{},
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
|||||||
t.Setenv(maxRetryTimeVar, "5s")
|
t.Setenv(maxRetryTimeVar, "5s")
|
||||||
t.Setenv(retryMultiplierVar, "1")
|
t.Setenv(retryMultiplierVar, "1")
|
||||||
|
|
||||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil)
|
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
|
||||||
if counter < 3 {
|
if counter < 3 {
|
||||||
t.Fatalf("expected counter > 2, got %d", counter)
|
t.Fatalf("expected counter > 2, got %d", counter)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,4 +4,5 @@ package iface
|
|||||||
type TunAdapter interface {
|
type TunAdapter interface {
|
||||||
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||||
UpdateAddr(address string) error
|
UpdateAddr(address string) error
|
||||||
|
ProtectSocket(fd int32) bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,7 +132,13 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
|
|||||||
|
|
||||||
lines := strings.Split(ipc, "\n")
|
lines := strings.Split(ipc, "\n")
|
||||||
|
|
||||||
output := ""
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
UpdateOnly: true,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: []net.IPNet{},
|
||||||
|
}
|
||||||
|
|
||||||
foundPeer := false
|
foundPeer := false
|
||||||
removedAllowedIP := false
|
removedAllowedIP := false
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
@@ -156,19 +162,23 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Append the line to the output string
|
// Append the line to the output string
|
||||||
if strings.HasPrefix(line, "private_key=") || strings.HasPrefix(line, "listen_port=") ||
|
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
||||||
strings.HasPrefix(line, "public_key=") || strings.HasPrefix(line, "preshared_key=") ||
|
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
|
||||||
strings.HasPrefix(line, "endpoint=") || strings.HasPrefix(line, "persistent_keepalive_interval=") ||
|
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||||
strings.HasPrefix(line, "allowed_ip=") {
|
if err != nil {
|
||||||
output += line + "\n"
|
return err
|
||||||
|
}
|
||||||
|
peer.AllowedIPs = append(peer.AllowedIPs, *ipNet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !removedAllowedIP {
|
if !removedAllowedIP {
|
||||||
return fmt.Errorf("allowedIP not found")
|
return fmt.Errorf("allowedIP not found")
|
||||||
} else {
|
|
||||||
return c.device.IpcSet(output)
|
|
||||||
}
|
}
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKey() (*wgtypes.Key, error)
|
GetServerPublicKey() (*wgtypes.Key, error)
|
||||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
mgmt "github.com/netbirdio/netbird/management/server"
|
mgmt "github.com/netbirdio/netbird/management/server"
|
||||||
@@ -255,7 +256,7 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
ch := make(chan *mgmtProto.SyncResponse, 1)
|
ch := make(chan *mgmtProto.SyncResponse, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err = client.Sync(func(msg *mgmtProto.SyncResponse) error {
|
err = client.Sync(context.Background(), func(msg *mgmtProto.SyncResponse) error {
|
||||||
ch <- msg
|
ch <- msg
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -113,8 +113,8 @@ func (c *GrpcClient) ready() bool {
|
|||||||
|
|
||||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||||
// Blocking request. The result will be sent via msgHandler callback function
|
// Blocking request. The result will be sent via msgHandler callback function
|
||||||
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
backOff := defaultBackoff(c.ctx)
|
backOff := defaultBackoff(ctx)
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
log.Debugf("management connection state %v", c.conn.GetState())
|
log.Debugf("management connection state %v", c.conn.GetState())
|
||||||
@@ -123,7 +123,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
if connState == connectivity.Shutdown {
|
if connState == connectivity.Shutdown {
|
||||||
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
|
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
|
||||||
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
||||||
c.conn.WaitForStateChange(c.ctx, connState)
|
c.conn.WaitForStateChange(ctx, connState)
|
||||||
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
ctx, cancelStream := context.WithCancel(ctx)
|
||||||
defer cancelStream()
|
defer cancelStream()
|
||||||
stream, err := c.connectToStream(ctx, *serverPubKey)
|
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -276,7 +276,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, fmt.Errorf("failed while getting Management Service public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -9,7 +11,7 @@ import (
|
|||||||
|
|
||||||
type MockClient struct {
|
type MockClient struct {
|
||||||
CloseFunc func() error
|
CloseFunc func() error
|
||||||
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
|
SyncFunc func(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
@@ -28,11 +30,11 @@ func (m *MockClient) Close() error {
|
|||||||
return m.CloseFunc()
|
return m.CloseFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
func (m *MockClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
if m.SyncFunc == nil {
|
if m.SyncFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.SyncFunc(msgHandler)
|
return m.SyncFunc(ctx, msgHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ type AccountManager interface {
|
|||||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
ListUsers(accountID string) ([]*User, error)
|
ListUsers(accountID string) ([]*User, error)
|
||||||
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error
|
MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error
|
||||||
DeletePeer(accountID, peerID, userID string) error
|
DeletePeer(accountID, peerID, userID string) error
|
||||||
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||||
@@ -100,10 +100,10 @@ type AccountManager interface {
|
|||||||
SavePolicy(accountID, userID string, policy *Policy) error
|
SavePolicy(accountID, userID string, policy *Policy) error
|
||||||
DeletePolicy(accountID, policyID, userID string) error
|
DeletePolicy(accountID, policyID, userID string) error
|
||||||
ListPolicies(accountID, userID string) ([]*Policy, error)
|
ListPolicies(accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(accountID, routeID, userID string) (*route.Route, error)
|
GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||||
SaveRoute(accountID, userID string, route *route.Route) error
|
SaveRoute(accountID, userID string, route *route.Route) error
|
||||||
DeleteRoute(accountID, routeID, userID string) error
|
DeleteRoute(accountID string, routeID route.ID, userID string) error
|
||||||
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
||||||
GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||||
@@ -117,8 +117,8 @@ type AccountManager interface {
|
|||||||
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||||
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||||
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
||||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||||
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||||
HasConnectedChannel(peerID string) bool
|
HasConnectedChannel(peerID string) bool
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
@@ -130,6 +130,8 @@ type AccountManager interface {
|
|||||||
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
|
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
|
||||||
GroupValidation(accountId string, groups []string) (bool, error)
|
GroupValidation(accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||||
|
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
|
||||||
|
CancelPeerRoutines(peer *nbpeer.Peer) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@@ -229,7 +231,7 @@ type Account struct {
|
|||||||
Groups map[string]*nbgroup.Group `gorm:"-"`
|
Groups map[string]*nbgroup.Group `gorm:"-"`
|
||||||
GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||||
Routes map[string]*route.Route `gorm:"-"`
|
Routes map[route.ID]*route.Route `gorm:"-"`
|
||||||
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
|
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
|
||||||
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
@@ -266,7 +268,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
|||||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
|
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
|
||||||
peerRoutesMembership := make(lookupMap)
|
peerRoutesMembership := make(lookupMap)
|
||||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||||
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{}
|
peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
groupListMap := a.getPeerGroups(peerID)
|
groupListMap := a.getPeerGroups(peerID)
|
||||||
@@ -284,7 +286,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
|||||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
||||||
var filteredRoutes []*route.Route
|
var filteredRoutes []*route.Route
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
_, found := peerMemberships[route.GetHAUniqueID(r)]
|
_, found := peerMemberships[string(route.GetHAUniqueID(r))]
|
||||||
if !found {
|
if !found {
|
||||||
filteredRoutes = append(filteredRoutes, r)
|
filteredRoutes = append(filteredRoutes, r)
|
||||||
}
|
}
|
||||||
@@ -323,7 +325,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
|
|||||||
return enabledRoutes, disabledRoutes
|
return enabledRoutes, disabledRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
seenRoute := make(map[string]struct{})
|
seenRoute := make(map[route.ID]struct{})
|
||||||
|
|
||||||
takeRoute := func(r *route.Route, id string) {
|
takeRoute := func(r *route.Route, id string) {
|
||||||
if _, ok := seenRoute[r.ID]; ok {
|
if _, ok := seenRoute[r.ID]; ok {
|
||||||
@@ -354,7 +356,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
|
|||||||
newPeerRoute := r.Copy()
|
newPeerRoute := r.Copy()
|
||||||
newPeerRoute.Peer = id
|
newPeerRoute.Peer = id
|
||||||
newPeerRoute.PeerGroups = nil
|
newPeerRoute.PeerGroups = nil
|
||||||
newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map
|
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||||
takeRoute(newPeerRoute, id)
|
takeRoute(newPeerRoute, id)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -693,7 +695,7 @@ func (a *Account) Copy() *Account {
|
|||||||
policies = append(policies, policy.Copy())
|
policies = append(policies, policy.Copy())
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := map[string]*route.Route{}
|
routes := map[route.ID]*route.Route{}
|
||||||
for id, r := range a.Routes {
|
for id, r := range a.Routes {
|
||||||
routes[id] = r.Copy()
|
routes[id] = r.Copy()
|
||||||
}
|
}
|
||||||
@@ -958,7 +960,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
|||||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -1009,7 +1011,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
|||||||
|
|
||||||
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
|
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
|
||||||
return func() (time.Duration, bool) {
|
return func() (time.Duration, bool) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -1108,7 +1110,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
|||||||
|
|
||||||
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
||||||
func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
|
func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1567,7 +1569,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err = am.Store.GetAccountByUser(user.Id)
|
account, err = am.Store.GetAccountByUser(user.Id)
|
||||||
@@ -1650,7 +1652,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
unlock := am.Store.AcquireAccountLock(newAcc.Id)
|
unlock := am.Store.AcquireAccountWriteLock(newAcc.Id)
|
||||||
alreadyUnlocked := false
|
alreadyUnlocked := false
|
||||||
defer func() {
|
defer func() {
|
||||||
if !alreadyUnlocked {
|
if !alreadyUnlocked {
|
||||||
@@ -1801,7 +1803,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
|||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(claims.UserId)
|
account, err := am.Store.GetAccountByUser(claims.UserId)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
unlockAccount := am.Store.AcquireAccountLock(account.Id)
|
unlockAccount := am.Store.AcquireAccountWriteLock(account.Id)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
account, err = am.Store.GetAccountByUser(claims.UserId)
|
account, err = am.Store.GetAccountByUser(claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1821,7 +1823,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
|||||||
return account, nil
|
return account, nil
|
||||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
if domainAccount != nil {
|
if domainAccount != nil {
|
||||||
unlockAccount := am.Store.AcquireAccountLock(domainAccount.Id)
|
unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain)
|
domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1835,6 +1837,56 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
|
||||||
|
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireAccountReadLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, mapError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer, netMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||||
|
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.MarkPeerConnected(peer.Key, false, nil, account)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed marking peer as connected %s %v", peer.Key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||||
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||||
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||||
@@ -1946,7 +1998,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
|||||||
network := NewNetwork()
|
network := NewNetwork()
|
||||||
peers := make(map[string]*nbpeer.Peer)
|
peers := make(map[string]*nbpeer.Peer)
|
||||||
users := make(map[string]*User)
|
users := make(map[string]*User)
|
||||||
routes := make(map[string]*route.Route)
|
routes := make(map[route.ID]*route.Route)
|
||||||
setupKeys := map[string]*SetupKey{}
|
setupKeys := map[string]*SetupKey{}
|
||||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||||
users[userID] = NewOwnerUser(userID)
|
users[userID] = NewOwnerUser(userID)
|
||||||
|
|||||||
@@ -1408,7 +1408,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"route-1": {
|
"route-1": {
|
||||||
ID: "route-1",
|
ID: "route-1",
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
@@ -1437,12 +1437,12 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
|||||||
routes := account.GetRoutesByPrefix(prefix)
|
routes := account.GetRoutesByPrefix(prefix)
|
||||||
|
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeIDs := make(map[string]struct{}, 2)
|
routeIDs := make(map[route.ID]struct{}, 2)
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
routeIDs[r.ID] = struct{}{}
|
routeIDs[r.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
assert.Contains(t, routeIDs, "route-1")
|
assert.Contains(t, routeIDs, route.ID("route-1"))
|
||||||
assert.Contains(t, routeIDs, "route-2")
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_GetRoutesToSync(t *testing.T) {
|
func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||||
@@ -1459,7 +1459,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
|||||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||||
},
|
},
|
||||||
Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"route-1": {
|
"route-1": {
|
||||||
ID: "route-1",
|
ID: "route-1",
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
@@ -1502,12 +1502,12 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
|||||||
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||||
|
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeIDs := make(map[string]struct{}, 2)
|
routeIDs := make(map[route.ID]struct{}, 2)
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
routeIDs[r.ID] = struct{}{}
|
routeIDs[r.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
assert.Contains(t, routeIDs, "route-2")
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
assert.Contains(t, routeIDs, "route-3")
|
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||||
|
|
||||||
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||||
|
|
||||||
@@ -1573,7 +1573,7 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
SourcePostureChecks: make([]string, 0),
|
SourcePostureChecks: make([]string, 0),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
PeerGroups: []string{},
|
PeerGroups: []string{},
|
||||||
@@ -1655,7 +1655,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
|||||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1666,7 +1666,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
@@ -1732,8 +1735,10 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
failed := waitTimeout(wg, time.Second)
|
failed := waitTimeout(wg, time.Second)
|
||||||
@@ -1745,7 +1750,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1756,7 +1761,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func (d DNSSettings) Copy() DNSSettings {
|
|||||||
|
|
||||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||||
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
|
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -57,7 +57,7 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
|
|||||||
|
|
||||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||||
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
// GetEvents returns a list of activity events of an account
|
// GetEvents returns a list of activity events of an account
|
||||||
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -279,8 +279,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
|
|||||||
return unlock
|
return unlock
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcquireAccountLock acquires account lock and returns a function that releases the lock
|
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
|
||||||
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
|
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||||
log.Debugf("acquiring lock for account %s", accountID)
|
log.Debugf("acquiring lock for account %s", accountID)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||||
@@ -295,6 +295,12 @@ func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
|
|||||||
return unlock
|
return unlock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
|
||||||
|
// This method is still returns a write lock as file store can't handle read locks
|
||||||
|
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||||
|
return s.AcquireAccountWriteLock(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FileStore) SaveAccount(account *Account) error {
|
func (s *FileStore) SaveAccount(account *Account) error {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -572,6 +578,18 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
|||||||
return account.Copy(), nil
|
return account.Copy(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
||||||
|
if !ok {
|
||||||
|
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetInstallationID returns the installation ID from the store
|
// GetInstallationID returns the installation ID from the store
|
||||||
func (s *FileStore) GetInstallationID() string {
|
func (s *FileStore) GetInstallationID() string {
|
||||||
return s.InstallationID
|
return s.InstallationID
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
|
|||||||
sha256sum: sha256sum,
|
sha256sum: sha256sum,
|
||||||
db: db,
|
db: db,
|
||||||
locationDB: locationDB,
|
locationDB: locationDB,
|
||||||
reloadCheckInterval: 60 * time.Second, // TODO: make configurable
|
reloadCheckInterval: 300 * time.Second, // TODO: make configurable
|
||||||
stopCh: make(chan struct{}),
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,7 +198,7 @@ func (gl *Geolocation) reloader() {
|
|||||||
log.Errorf("mmdb reload failed: %s", err)
|
log.Errorf("mmdb reload failed: %s", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
|
log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
|
||||||
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
|
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func (s *SqliteStore) reload() error {
|
|||||||
|
|
||||||
log.Infof("Successfully reloaded '%s'", s.filePath)
|
log.Infof("Successfully reloaded '%s'", s.filePath)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("No changes in '%s', no need to reload", s.filePath)
|
log.Tracef("No changes in '%s', no need to reload", s.filePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func (e *GroupLinkError) Error() string {
|
|||||||
|
|
||||||
// GetGroup object of the peers
|
// GetGroup object of the peers
|
||||||
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
|
|||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -76,7 +76,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
|
|||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -109,7 +109,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
|
|||||||
|
|
||||||
// SaveGroup object of the peers
|
// SaveGroup object of the peers
|
||||||
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
|
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -214,7 +214,7 @@ func difference(a, b []string) []string {
|
|||||||
|
|
||||||
// DeleteGroup object of the peers
|
// DeleteGroup object of the peers
|
||||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountId)
|
unlock := am.Store.AcquireAccountWriteLock(accountId)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountId)
|
account, err := am.Store.GetAccount(accountId)
|
||||||
@@ -242,7 +242,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
|||||||
for _, r := range account.Routes {
|
for _, r := range account.Routes {
|
||||||
for _, g := range r.Groups {
|
for _, g := range r.Groups {
|
||||||
if g == groupID {
|
if g == groupID {
|
||||||
return &GroupLinkError{"route", r.NetID}
|
return &GroupLinkError{"route", string(r.NetID)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
|||||||
|
|
||||||
// ListGroups objects of the peers
|
// ListGroups objects of the peers
|
||||||
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -341,7 +341,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
|
|||||||
|
|
||||||
// GroupAddPeer appends peer to the group
|
// GroupAddPeer appends peer to the group
|
||||||
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
|
|||||||
|
|
||||||
// GroupDeletePeer removes peer from the group
|
// GroupDeletePeer removes peer from the group
|
||||||
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -134,9 +134,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
|
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return mapError(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||||
@@ -149,11 +149,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
|
|
||||||
s.ephemeralManager.OnPeerConnected(peer)
|
s.ephemeralManager.OnPeerConnected(peer)
|
||||||
|
|
||||||
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.config.TURNConfig.TimeBasedCredentials {
|
if s.config.TURNConfig.TimeBasedCredentials {
|
||||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||||
}
|
}
|
||||||
@@ -207,7 +202,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||||
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil)
|
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -335,6 +335,10 @@ components:
|
|||||||
$ref: '#/components/schemas/CountryCode'
|
$ref: '#/components/schemas/CountryCode'
|
||||||
city_name:
|
city_name:
|
||||||
$ref: '#/components/schemas/CityName'
|
$ref: '#/components/schemas/CityName'
|
||||||
|
serial_number:
|
||||||
|
description: System serial number
|
||||||
|
type: string
|
||||||
|
example: "C02XJ0J0JGH7"
|
||||||
required:
|
required:
|
||||||
- city_name
|
- city_name
|
||||||
- connected
|
- connected
|
||||||
@@ -356,6 +360,7 @@ components:
|
|||||||
- version
|
- version
|
||||||
- ui_version
|
- ui_version
|
||||||
- approval_required
|
- approval_required
|
||||||
|
- serial_number
|
||||||
AccessiblePeer:
|
AccessiblePeer:
|
||||||
allOf:
|
allOf:
|
||||||
- $ref: '#/components/schemas/PeerMinimum'
|
- $ref: '#/components/schemas/PeerMinimum'
|
||||||
|
|||||||
@@ -523,6 +523,9 @@ type Peer struct {
|
|||||||
// Os Peer's operating system and version
|
// Os Peer's operating system and version
|
||||||
Os string `json:"os"`
|
Os string `json:"os"`
|
||||||
|
|
||||||
|
// SerialNumber System serial number
|
||||||
|
SerialNumber string `json:"serial_number"`
|
||||||
|
|
||||||
// SshEnabled Indicates whether SSH server is enabled on this peer
|
// SshEnabled Indicates whether SSH server is enabled on this peer
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
|
|
||||||
@@ -592,6 +595,9 @@ type PeerBase struct {
|
|||||||
// Os Peer's operating system and version
|
// Os Peer's operating system and version
|
||||||
Os string `json:"os"`
|
Os string `json:"os"`
|
||||||
|
|
||||||
|
// SerialNumber System serial number
|
||||||
|
SerialNumber string `json:"serial_number"`
|
||||||
|
|
||||||
// SshEnabled Indicates whether SSH server is enabled on this peer
|
// SshEnabled Indicates whether SSH server is enabled on this peer
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
|
|
||||||
@@ -664,6 +670,9 @@ type PeerBatch struct {
|
|||||||
// Os Peer's operating system and version
|
// Os Peer's operating system and version
|
||||||
Os string `json:"os"`
|
Os string `json:"os"`
|
||||||
|
|
||||||
|
// SerialNumber System serial number
|
||||||
|
SerialNumber string `json:"serial_number"`
|
||||||
|
|
||||||
// SshEnabled Indicates whether SSH server is enabled on this peer
|
// SshEnabled Indicates whether SSH server is enabled on this peer
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
|
|
||||||
|
|||||||
@@ -308,6 +308,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
|||||||
ApprovalRequired: !approved,
|
ApprovalRequired: !approved,
|
||||||
CountryCode: peer.Location.CountryCode,
|
CountryCode: peer.Location.CountryCode,
|
||||||
CityName: peer.Location.CityName,
|
CityName: peer.Location.CityName,
|
||||||
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,6 +341,7 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
|||||||
AccessiblePeersCount: accessiblePeersCount,
|
AccessiblePeersCount: accessiblePeersCount,
|
||||||
CountryCode: peer.Location.CountryCode,
|
CountryCode: peer.Location.CountryCode,
|
||||||
CityName: peer.Location.CityName,
|
CityName: peer.Location.CityName,
|
||||||
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -121,13 +121,14 @@ func TestGetPeers(t *testing.T) {
|
|||||||
Name: "PeerName",
|
Name: "PeerName",
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
Meta: nbpeer.PeerSystemMeta{
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
Hostname: "hostname",
|
Hostname: "hostname",
|
||||||
GoOS: "GoOS",
|
GoOS: "GoOS",
|
||||||
Kernel: "kernel",
|
Kernel: "kernel",
|
||||||
Core: "core",
|
Core: "core",
|
||||||
Platform: "platform",
|
Platform: "platform",
|
||||||
OS: "OS",
|
OS: "OS",
|
||||||
WtVersion: "development",
|
WtVersion: "development",
|
||||||
|
SystemSerialNumber: "C02XJ0J0JGH7",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,6 +246,7 @@ func TestGetPeers(t *testing.T) {
|
|||||||
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
|
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
|
||||||
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
|
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
|
||||||
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
|
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
|
||||||
|
assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
newRoute, err := h.accountManager.CreateRoute(
|
newRoute, err := h.accountManager.CreateRoute(
|
||||||
account.Id, newPrefix.String(), peerId, peerGroupIds,
|
account.Id, newPrefix.String(), peerId, peerGroupIds,
|
||||||
req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
|
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
@@ -135,7 +135,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@@ -185,9 +185,9 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRoute := &route.Route{
|
newRoute := &route.Route{
|
||||||
ID: routeID,
|
ID: route.ID(routeID),
|
||||||
Network: newPrefix,
|
Network: newPrefix,
|
||||||
NetID: req.NetworkId,
|
NetID: route.NetID(req.NetworkId),
|
||||||
NetworkType: prefixType,
|
NetworkType: prefixType,
|
||||||
Masquerade: req.Masquerade,
|
Masquerade: req.Masquerade,
|
||||||
Metric: req.Metric,
|
Metric: req.Metric,
|
||||||
@@ -230,7 +230,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteRoute(account.Id, routeID, user.Id)
|
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@@ -254,7 +254,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
||||||
return
|
return
|
||||||
@@ -265,9 +265,9 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func toRouteResponse(serverRoute *route.Route) *api.Route {
|
func toRouteResponse(serverRoute *route.Route) *api.Route {
|
||||||
route := &api.Route{
|
route := &api.Route{
|
||||||
Id: serverRoute.ID,
|
Id: string(serverRoute.ID),
|
||||||
Description: serverRoute.Description,
|
Description: serverRoute.Description,
|
||||||
NetworkId: serverRoute.NetID,
|
NetworkId: string(serverRoute.NetID),
|
||||||
Enabled: serverRoute.Enabled,
|
Enabled: serverRoute.Enabled,
|
||||||
Peer: &serverRoute.Peer,
|
Peer: &serverRoute.Peer,
|
||||||
Network: serverRoute.Network.String(),
|
Network: serverRoute.Network.String(),
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ var testingAccount = &server.Account{
|
|||||||
func initRoutesTestData() *RoutesHandler {
|
func initRoutesTestData() *RoutesHandler {
|
||||||
return &RoutesHandler{
|
return &RoutesHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
|
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||||
if routeID == existingRouteID {
|
if routeID == existingRouteID {
|
||||||
return baseExistingRoute, nil
|
return baseExistingRoute, nil
|
||||||
}
|
}
|
||||||
@@ -93,7 +93,7 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||||
},
|
},
|
||||||
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
|
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
|
||||||
if peerID == notFoundPeerID {
|
if peerID == notFoundPeerID {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
@@ -120,7 +120,7 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
DeleteRouteFunc: func(_ string, routeID string, _ string) error {
|
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
|
||||||
if routeID != existingRouteID {
|
if routeID != existingRouteID {
|
||||||
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
|
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
|
|||||||
return errors.New("invalid groups")
|
return errors.New("invalid groups")
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
a, err := am.Store.GetAccountByUser(userID)
|
a, err := am.Store.GetAccountByUser(userID)
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
|
|||||||
SourcePostureChecks: []string{"1"},
|
SourcePostureChecks: []string{"1"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"1": {
|
"1": {
|
||||||
ID: "1",
|
ID: "1",
|
||||||
PeerGroups: make([]string, 1),
|
PeerGroups: make([]string, 1),
|
||||||
@@ -151,7 +151,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"1": {
|
"1": {
|
||||||
ID: "1",
|
ID: "1",
|
||||||
PeerGroups: make([]string, 1),
|
PeerGroups: make([]string, 1),
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package migration
|
package migration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -99,3 +101,104 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
|
||||||
|
// T is the type of the model that contains the field to be migrated.
|
||||||
|
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
|
||||||
|
oldColumnName := fieldName
|
||||||
|
newColumnName := fieldName + "_tmp"
|
||||||
|
|
||||||
|
var model T
|
||||||
|
|
||||||
|
if !db.Migrator().HasTable(&model) {
|
||||||
|
log.Printf("Table for %T does not exist, no migration needed", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := &gorm.Statement{DB: db}
|
||||||
|
err := stmt.Parse(&model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse model: %w", err)
|
||||||
|
}
|
||||||
|
tableName := stmt.Schema.Table
|
||||||
|
|
||||||
|
var item sql.NullString
|
||||||
|
if err := db.Model(&model).Select(oldColumnName).First(&item).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
log.Printf("No records in table %s, no migration needed", tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("fetch first record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.Valid {
|
||||||
|
var js json.RawMessage
|
||||||
|
var syntaxError *json.SyntaxError
|
||||||
|
err = json.Unmarshal([]byte(item.String), &js)
|
||||||
|
if err == nil || !errors.As(err, &syntaxError) {
|
||||||
|
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
|
||||||
|
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rows []map[string]any
|
||||||
|
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
log.Printf("No records in table %s, no migration needed", tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("find rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, row := range rows {
|
||||||
|
var blobValue string
|
||||||
|
if columnValue := row[oldColumnName]; columnValue != nil {
|
||||||
|
value, ok := columnValue.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("type assertion failed")
|
||||||
|
}
|
||||||
|
blobValue = value
|
||||||
|
}
|
||||||
|
|
||||||
|
columnIpValue := net.IP(blobValue)
|
||||||
|
if net.ParseIP(columnIpValue.String()) == nil {
|
||||||
|
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
|
||||||
|
columnIpValue = net.IPv6loopback
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonValue, err := json.Marshal(columnIpValue)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("re-encode to JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
|
||||||
|
return fmt.Errorf("update row: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if indexName != "" {
|
||||||
|
if err := tx.Migrator().DropIndex(&model, indexName); err != nil {
|
||||||
|
return fmt.Errorf("drop index %s: %w", indexName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
|
||||||
|
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
|
||||||
|
}
|
||||||
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
|
||||||
|
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Migration of %s.%s from blob to json completed", tableName, fieldName)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/migration"
|
"github.com/netbirdio/netbird/management/server/migration"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,3 +90,72 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
|
|||||||
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
|
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
|
||||||
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
|
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||||
|
require.NoError(t, err, "Migration should not fail for an empty database")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
|
||||||
|
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||||
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
|
type location struct {
|
||||||
|
nbpeer.Location
|
||||||
|
ConnectionIP net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
type peer struct {
|
||||||
|
nbpeer.Peer
|
||||||
|
Location location `gorm:"embedded;embeddedPrefix:location_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type account struct {
|
||||||
|
server.Account
|
||||||
|
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Save(&account{
|
||||||
|
Account: server.Account{Id: "123"},
|
||||||
|
Peers: []peer{
|
||||||
|
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||||
|
}},
|
||||||
|
).Error
|
||||||
|
require.NoError(t, err, "Failed to insert blob data")
|
||||||
|
|
||||||
|
var blobValue string
|
||||||
|
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
|
||||||
|
assert.NoError(t, err, "Failed to fetch blob data")
|
||||||
|
|
||||||
|
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||||
|
require.NoError(t, err, "Migration should not fail with net.IP blob data")
|
||||||
|
|
||||||
|
var jsonStr string
|
||||||
|
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||||
|
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be migrated")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
|
||||||
|
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||||
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
|
err = db.Save(&server.Account{
|
||||||
|
Id: "1234",
|
||||||
|
PeersG: []nbpeer.Peer{
|
||||||
|
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||||
|
}},
|
||||||
|
).Error
|
||||||
|
require.NoError(t, err, "Failed to insert JSON data")
|
||||||
|
|
||||||
|
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||||
|
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
|
||||||
|
|
||||||
|
var jsonStr string
|
||||||
|
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||||
|
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,80 +22,93 @@ type MockAccountManager struct {
|
|||||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||||
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||||
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||||
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
||||||
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
||||||
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
||||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
||||||
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
||||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||||
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||||
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
||||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
||||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||||
MarkPATUsedFunc func(pat string) error
|
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
MarkPATUsedFunc func(pat string) error
|
||||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||||
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
SaveRouteFunc func(accountID string, userID string, route *route.Route) error
|
||||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
|
||||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||||
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||||
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||||
DeleteAccountFunc func(accountID, userID string) error
|
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||||
GetDNSDomainFunc func() string
|
DeleteAccountFunc func(accountID, userID string) error
|
||||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
GetDNSDomainFunc func() string
|
||||||
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
||||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||||
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
|
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||||
HasConnectedChannelFunc func(peerID string) bool
|
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
HasConnectedChannelFunc func(peerID string) bool
|
||||||
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||||
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
||||||
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
||||||
GetIdpManagerFunc func() idp.Manager
|
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
||||||
|
GetIdpManagerFunc func() idp.Manager
|
||||||
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
||||||
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||||
|
if am.SyncAndMarkPeerFunc != nil {
|
||||||
|
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
|
||||||
|
}
|
||||||
|
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
||||||
approvedPeers := make(map[string]struct{})
|
approvedPeers := make(map[string]struct{})
|
||||||
for id := range account.Peers {
|
for id := range account.Peers {
|
||||||
@@ -180,7 +193,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error {
|
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||||
if am.MarkPeerConnectedFunc != nil {
|
if am.MarkPeerConnectedFunc != nil {
|
||||||
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
|
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
|
||||||
}
|
}
|
||||||
@@ -399,15 +412,15 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||||
if am.CreateRouteFunc != nil {
|
if am.CreateRouteFunc != nil {
|
||||||
return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID)
|
return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoute mock implementation of GetRoute from server.AccountManager interface
|
// GetRoute mock implementation of GetRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||||
if am.GetRouteFunc != nil {
|
if am.GetRouteFunc != nil {
|
||||||
return am.GetRouteFunc(accountID, routeID, userID)
|
return am.GetRouteFunc(accountID, routeID, userID)
|
||||||
}
|
}
|
||||||
@@ -415,7 +428,7 @@ func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*rout
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface
|
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.Route) error {
|
func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error {
|
||||||
if am.SaveRouteFunc != nil {
|
if am.SaveRouteFunc != nil {
|
||||||
return am.SaveRouteFunc(accountID, userID, route)
|
return am.SaveRouteFunc(accountID, userID, route)
|
||||||
}
|
}
|
||||||
@@ -423,7 +436,7 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
|
||||||
if am.DeleteRouteFunc != nil {
|
if am.DeleteRouteFunc != nil {
|
||||||
return am.DeleteRouteFunc(accountID, routeID, userID)
|
return am.DeleteRouteFunc(accountID, routeID, userID)
|
||||||
}
|
}
|
||||||
@@ -626,9 +639,9 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||||
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) {
|
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||||
if am.SyncPeerFunc != nil {
|
if am.SyncPeerFunc != nil {
|
||||||
return am.SyncPeerFunc(sync)
|
return am.SyncPeerFunc(sync, account)
|
||||||
}
|
}
|
||||||
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
|||||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||||
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -47,7 +47,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
|
|||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -94,7 +94,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
|||||||
// SaveNameServerGroup saves nameserver group
|
// SaveNameServerGroup saves nameserver group
|
||||||
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if nsGroupToSave == nil {
|
if nsGroupToSave == nil {
|
||||||
@@ -129,7 +129,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
|||||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||||
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -159,7 +159,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
|
|||||||
// ListNameServerGroups returns a list of nameserver groups from account
|
// ListNameServerGroups returns a list of nameserver groups from account
|
||||||
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -88,21 +88,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error {
|
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
|
||||||
account, err = am.Store.GetAccount(account.Id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -156,7 +142,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
|||||||
|
|
||||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||||
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -278,7 +264,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
|
|||||||
|
|
||||||
// DeletePeer removes peer from the account by its IP
|
// DeletePeer removes peer from the account by its IP
|
||||||
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -362,7 +348,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
|||||||
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||||
@@ -381,7 +367,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||||
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountLock (e.g., database is slow)
|
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
|
||||||
// and the peer disconnects with a timeout and tries to register again.
|
// and the peer disconnects with a timeout and tries to register again.
|
||||||
// We just check if this machine has been registered before and reject the second registration.
|
// We just check if this machine has been registered before and reject the second registration.
|
||||||
// The connecting peer should be able to recover with a retry.
|
// The connecting peer should be able to recover with a retry.
|
||||||
@@ -518,25 +504,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) {
|
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
|
||||||
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
|
|
||||||
if err != nil {
|
|
||||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
|
||||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
|
||||||
}
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// we found the peer, and we follow a normal login flow
|
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
|
||||||
account, err = am.Store.GetAccount(account.Id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||||
@@ -603,7 +571,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
|||||||
}
|
}
|
||||||
|
|
||||||
// we found the peer, and we follow a normal login flow
|
// we found the peer, and we follow a normal login flow
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||||
@@ -760,7 +728,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||||
@@ -795,7 +763,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
|||||||
|
|
||||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||||
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ type Peer struct {
|
|||||||
// ID is an internal ID of the peer
|
// ID is an internal ID of the peer
|
||||||
ID string `gorm:"primaryKey"`
|
ID string `gorm:"primaryKey"`
|
||||||
// AccountID is a reference to Account that this object belongs
|
// AccountID is a reference to Account that this object belongs
|
||||||
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
|
AccountID string `json:"-" gorm:"index"`
|
||||||
// WireGuard public key
|
// WireGuard public key
|
||||||
Key string `gorm:"index"`
|
Key string `gorm:"index"`
|
||||||
// A setup key this peer was registered with
|
// A setup key this peer was registered with
|
||||||
SetupKey string
|
SetupKey string
|
||||||
// IP address of the Peer
|
// IP address of the Peer
|
||||||
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
|
IP net.IP `gorm:"serializer:json"`
|
||||||
// Meta is a Peer system meta data
|
// Meta is a Peer system meta data
|
||||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||||
// Name is peer's name (machine name)
|
// Name is peer's name (machine name)
|
||||||
@@ -61,7 +61,7 @@ type PeerStatus struct { //nolint:revive
|
|||||||
|
|
||||||
// Location is a geo location information of a Peer based on public connection IP
|
// Location is a geo location information of a Peer based on public connection IP
|
||||||
type Location struct {
|
type Location struct {
|
||||||
ConnectionIP net.IP // from grpc peer or reverse proxy headers depends on setup
|
ConnectionIP net.IP `gorm:"serializer:json"` // from grpc peer or reverse proxy headers depends on setup
|
||||||
CountryCode string
|
CountryCode string
|
||||||
CityName string
|
CityName string
|
||||||
GeoNameID uint // city level geoname id
|
GeoNameID uint // city level geoname id
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
|
|||||||
|
|
||||||
// GetPolicy from the store
|
// GetPolicy from the store
|
||||||
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
|
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -342,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
|
|||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
|
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
|
|||||||
|
|
||||||
// DeletePolicy from the store
|
// DeletePolicy from the store
|
||||||
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
|
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -397,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
|
|||||||
|
|
||||||
// ListPolicies from the store
|
// ListPolicies from the store
|
||||||
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
|
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -34,7 +34,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
|
func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -113,7 +113,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -40,7 +40,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||||
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error {
|
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error {
|
||||||
// routes can have both peer and peer_groups
|
// routes can have both peer and peer_groups
|
||||||
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
|
|||||||
}
|
}
|
||||||
|
|
||||||
if prefixRoute.Peer != "" {
|
if prefixRoute.Peer != "" {
|
||||||
seenPeers[prefixRoute.ID] = true
|
seenPeers[string(prefixRoute.ID)] = true
|
||||||
}
|
}
|
||||||
for _, groupID := range prefixRoute.PeerGroups {
|
for _, groupID := range prefixRoute.PeerGroups {
|
||||||
seenPeerGroups[groupID] = true
|
seenPeerGroups[groupID] = true
|
||||||
@@ -114,8 +114,8 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateRoute creates and saves a new route
|
// CreateRoute creates and saves a new route
|
||||||
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -131,7 +131,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
var newRoute route.Route
|
var newRoute route.Route
|
||||||
newRoute.ID = xid.New().String()
|
newRoute.ID = route.ID(xid.New().String())
|
||||||
|
|
||||||
prefixType, newPrefix, err := route.ParseNetwork(network)
|
prefixType, newPrefix, err := route.ParseNetwork(network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
|
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
newRoute.Groups = groups
|
newRoute.Groups = groups
|
||||||
|
|
||||||
if account.Routes == nil {
|
if account.Routes == nil {
|
||||||
account.Routes = make(map[string]*route.Route)
|
account.Routes = make(map[route.ID]*route.Route)
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Routes[newRoute.ID] = &newRoute
|
account.Routes[newRoute.ID] = &newRoute
|
||||||
@@ -187,14 +187,14 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
|
|
||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
am.StoreEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta())
|
am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||||
|
|
||||||
return &newRoute, nil
|
return &newRoute, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveRoute saves route
|
// SaveRoute saves route
|
||||||
func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error {
|
func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if routeToSave == nil {
|
if routeToSave == nil {
|
||||||
@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
|||||||
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
|
if utf8.RuneCountInString(string(routeToSave.NetID)) > route.MaxNetIDChar || routeToSave.NetID == "" {
|
||||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,14 +248,14 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
|||||||
|
|
||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
am.StoreEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoute deletes route with routeID
|
// DeleteRoute deletes route with routeID
|
||||||
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -274,7 +274,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta())
|
am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
|
|||||||
|
|
||||||
// ListRoutes returns a list of routes from account
|
// ListRoutes returns a list of routes from account
|
||||||
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
|
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -310,8 +310,8 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
|
|||||||
|
|
||||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||||
return &proto.Route{
|
return &proto.Route{
|
||||||
ID: route.ID,
|
ID: string(route.ID),
|
||||||
NetID: route.NetID,
|
NetID: string(route.NetID),
|
||||||
Network: route.Network.String(),
|
Network: route.Network.String(),
|
||||||
NetworkType: int64(route.NetworkType),
|
NetworkType: int64(route.NetworkType),
|
||||||
Peer: route.Peer,
|
Peer: route.Peer,
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ const (
|
|||||||
func TestCreateRoute(t *testing.T) {
|
func TestCreateRoute(t *testing.T) {
|
||||||
type input struct {
|
type input struct {
|
||||||
network string
|
network string
|
||||||
netID string
|
netID route.NetID
|
||||||
peerKey string
|
peerKey string
|
||||||
peerGroupIDs []string
|
peerGroupIDs []string
|
||||||
description string
|
description string
|
||||||
@@ -382,8 +382,8 @@ func TestSaveRoute(t *testing.T) {
|
|||||||
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
|
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
|
||||||
validMetric := 1000
|
validMetric := 1000
|
||||||
invalidMetric := 99999
|
invalidMetric := 99999
|
||||||
validNetID := "12345678901234567890qw"
|
validNetID := route.NetID("12345678901234567890qw")
|
||||||
invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1"
|
invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1")
|
||||||
validGroupHA1 := routeGroupHA1
|
validGroupHA1 := routeGroupHA1
|
||||||
validGroupHA2 := routeGroupHA2
|
validGroupHA2 := routeGroupHA2
|
||||||
|
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func Hash(s string) uint32 {
|
|||||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||||
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
|
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
keyDuration := DefaultSetupKeyDuration
|
keyDuration := DefaultSetupKeyDuration
|
||||||
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
|
|||||||
// (e.g. the key itself, creation date, ID, etc).
|
// (e.g. the key itself, creation date, ID, etc).
|
||||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||||
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if keyToSave == nil {
|
if keyToSave == nil {
|
||||||
@@ -327,7 +327,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
|||||||
|
|
||||||
// ListSetupKeys returns a list of all setup keys of the account
|
// ListSetupKeys returns a list of all setup keys of the account
|
||||||
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
|
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -359,7 +359,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
|
|||||||
|
|
||||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||||
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
|
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
|||||||
@@ -127,17 +127,33 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
|
|||||||
return unlock
|
return unlock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
|
func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||||
log.Tracef("acquiring lock for account %s", accountID)
|
log.Tracef("acquiring write lock for account %s", accountID)
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||||
mtx := value.(*sync.Mutex)
|
mtx := value.(*sync.RWMutex)
|
||||||
mtx.Lock()
|
mtx.Lock()
|
||||||
|
|
||||||
unlock = func() {
|
unlock = func() {
|
||||||
mtx.Unlock()
|
mtx.Unlock()
|
||||||
log.Tracef("released lock for account %s in %v", accountID, time.Since(start))
|
log.Tracef("released write lock for account %s in %v", accountID, time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
return unlock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||||
|
log.Tracef("acquiring read lock for account %s", accountID)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||||
|
mtx := value.(*sync.RWMutex)
|
||||||
|
mtx.RLock()
|
||||||
|
|
||||||
|
unlock = func() {
|
||||||
|
mtx.RUnlock()
|
||||||
|
log.Tracef("released read lock for account %s in %v", accountID, time.Since(start))
|
||||||
}
|
}
|
||||||
|
|
||||||
return unlock
|
return unlock
|
||||||
@@ -263,36 +279,43 @@ func (s *SqliteStore) GetInstallationID() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||||
var peer nbpeer.Peer
|
var peerCopy nbpeer.Peer
|
||||||
|
peerCopy.Status = &peerStatus
|
||||||
|
result := s.db.Model(&nbpeer.Peer{}).
|
||||||
|
Where("account_id = ? AND id = ?", accountID, peerID).
|
||||||
|
Updates(peerCopy)
|
||||||
|
|
||||||
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
return result.Error
|
||||||
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
|
||||||
}
|
|
||||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return status.Errorf(status.Internal, "issue getting peer from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.Status = &peerStatus
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
||||||
|
}
|
||||||
|
|
||||||
return s.db.Save(peer).Error
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||||
var peer nbpeer.Peer
|
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||||
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID)
|
var peerCopy nbpeer.Peer
|
||||||
|
// Since the location field has been migrated to JSON serialization,
|
||||||
|
// updating the struct ensures the correct data format is inserted into the database.
|
||||||
|
peerCopy.Location = peerWithLocation.Location
|
||||||
|
|
||||||
|
result := s.db.Model(&nbpeer.Peer{}).
|
||||||
|
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
|
||||||
|
Updates(peerCopy)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
return result.Error
|
||||||
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
|
|
||||||
}
|
|
||||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return status.Errorf(status.Internal, "issue getting peer from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.Location = peerWithLocation.Location
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
|
||||||
|
}
|
||||||
|
|
||||||
return s.db.Save(peer).Error
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
|
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
|
||||||
@@ -400,6 +423,7 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||||
|
|
||||||
var account Account
|
var account Account
|
||||||
result := s.db.Model(&account).
|
result := s.db.Model(&account).
|
||||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||||
@@ -451,7 +475,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
|||||||
}
|
}
|
||||||
account.GroupsG = nil
|
account.GroupsG = nil
|
||||||
|
|
||||||
account.Routes = make(map[string]*route.Route, len(account.RoutesG))
|
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||||
for _, route := range account.RoutesG {
|
for _, route := range account.RoutesG {
|
||||||
account.Routes[route.ID] = route.Copy()
|
account.Routes[route.ID] = route.Copy()
|
||||||
}
|
}
|
||||||
@@ -521,6 +545,21 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
|||||||
return s.GetAccount(peer.AccountID)
|
return s.GetAccount(peer.AccountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||||
|
var peer nbpeer.Peer
|
||||||
|
var accountID string
|
||||||
|
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
log.Errorf("error when getting peer from the store: %s", result.Error)
|
||||||
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||||
var user User
|
var user User
|
||||||
@@ -571,13 +610,17 @@ func getMigrations() []migrationFunc {
|
|||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
|
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
|
||||||
},
|
},
|
||||||
|
|
||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
|
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
|
||||||
},
|
},
|
||||||
|
|
||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
|
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
|
||||||
},
|
},
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||||
|
},
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -12,6 +10,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -75,9 +76,9 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
|
|||||||
}
|
}
|
||||||
account.Users[user.Id] = user
|
account.Users[user.Id] = user
|
||||||
route := &route2.Route{
|
route := &route2.Route{
|
||||||
ID: fmt.Sprintf("network-id-%d", n),
|
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
|
||||||
Description: "base route",
|
Description: "base route",
|
||||||
NetID: fmt.Sprintf("network-id-%d", n),
|
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
|
||||||
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
||||||
NetworkType: route2.IPv4Network,
|
NetworkType: route2.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
@@ -518,15 +519,29 @@ func TestMigrate(t *testing.T) {
|
|||||||
Net net.IPNet `gorm:"serializer:gob"`
|
Net net.IPNet `gorm:"serializer:gob"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type location struct {
|
||||||
|
nbpeer.Location
|
||||||
|
ConnectionIP net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
type peer struct {
|
||||||
|
nbpeer.Peer
|
||||||
|
Location location `gorm:"embedded;embeddedPrefix:location_"`
|
||||||
|
}
|
||||||
|
|
||||||
type account struct {
|
type account struct {
|
||||||
Account
|
Account
|
||||||
Network *network `gorm:"embedded;embeddedPrefix:network_"`
|
Network *network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
|
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
act := &account{
|
act := &account{
|
||||||
Network: &network{
|
Network: &network{
|
||||||
Net: *ipnet,
|
Net: *ipnet,
|
||||||
},
|
},
|
||||||
|
Peers: []peer{
|
||||||
|
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = store.db.Save(act).Error
|
err = store.db.Save(act).Error
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type Store interface {
|
|||||||
DeleteAccount(account *Account) error
|
DeleteAccount(account *Account) error
|
||||||
GetAccountByUser(userID string) (*Account, error)
|
GetAccountByUser(userID string) (*Account, error)
|
||||||
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
||||||
|
GetAccountIDByPeerPubKey(peerKey string) (string, error)
|
||||||
GetAccountByPeerID(peerID string) (*Account, error)
|
GetAccountByPeerID(peerID string) (*Account, error)
|
||||||
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
|
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
|
||||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||||
@@ -29,8 +30,10 @@ type Store interface {
|
|||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
SaveInstallationID(ID string) error
|
SaveInstallationID(ID string) error
|
||||||
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
|
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
|
||||||
AcquireAccountLock(accountID string) func()
|
AcquireAccountWriteLock(accountID string) func()
|
||||||
|
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
|
||||||
|
AcquireAccountReadLock(accountID string) func()
|
||||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||||
AcquireGlobalLock() func()
|
AcquireGlobalLock() func()
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ func NewOwnerUser(id string) *User {
|
|||||||
|
|
||||||
// createServiceUser creates a new service user under the given account.
|
// createServiceUser creates a new service user under the given account.
|
||||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -266,7 +266,7 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *User
|
|||||||
|
|
||||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||||
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if am.idpManager == nil {
|
if am.idpManager == nil {
|
||||||
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
|
|||||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err = am.Store.GetAccount(account.Id)
|
account, err = am.Store.GetAccount(account.Id)
|
||||||
@@ -400,7 +400,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
|
|||||||
// ListUsers returns lists of all users under the account.
|
// ListUsers returns lists of all users under the account.
|
||||||
// It doesn't populate user information such as email or name.
|
// It doesn't populate user information such as email or name.
|
||||||
func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) {
|
func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
|
|||||||
if initiatorUserID == targetUserID {
|
if initiatorUserID == targetUserID {
|
||||||
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
|
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
|
||||||
}
|
}
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -537,7 +537,7 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU
|
|||||||
|
|
||||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||||
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if am.idpManager == nil {
|
if am.idpManager == nil {
|
||||||
@@ -577,7 +577,7 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st
|
|||||||
|
|
||||||
// CreatePAT creates a new PAT for the given user
|
// CreatePAT creates a new PAT for the given user
|
||||||
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if tokenName == "" {
|
if tokenName == "" {
|
||||||
@@ -627,7 +627,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str
|
|||||||
|
|
||||||
// DeletePAT deletes a specific PAT from a user
|
// DeletePAT deletes a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -677,7 +677,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str
|
|||||||
|
|
||||||
// GetPAT returns a specific PAT from a user
|
// GetPAT returns a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -709,7 +709,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string
|
|||||||
|
|
||||||
// GetAllPATs returns all PATs for a user
|
// GetAllPATs returns all PATs for a user
|
||||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
@@ -747,7 +747,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
|
|||||||
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
||||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||||
func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
|
func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if update == nil {
|
if update == nil {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ RestartSec=5
|
|||||||
TimeoutStopSec=10
|
TimeoutStopSec=10
|
||||||
CacheDirectory=netbird
|
CacheDirectory=netbird
|
||||||
ConfigurationDirectory=netbird
|
ConfigurationDirectory=netbird
|
||||||
LogDirectory=netbird
|
LogsDirectory=netbird
|
||||||
RuntimeDirectory=netbird
|
RuntimeDirectory=netbird
|
||||||
StateDirectory=netbird
|
StateDirectory=netbird
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ RestartSec=5
|
|||||||
TimeoutStopSec=10
|
TimeoutStopSec=10
|
||||||
CacheDirectory=netbird
|
CacheDirectory=netbird
|
||||||
ConfigurationDirectory=netbird
|
ConfigurationDirectory=netbird
|
||||||
LogDirectory=netbird
|
LogsDirectory=netbird
|
||||||
RuntimeDirectory=netbird
|
RuntimeDirectory=netbird
|
||||||
StateDirectory=netbird
|
StateDirectory=netbird
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ RestartSec=5
|
|||||||
TimeoutStopSec=10
|
TimeoutStopSec=10
|
||||||
CacheDirectory=netbird
|
CacheDirectory=netbird
|
||||||
ConfigurationDirectory=netbird
|
ConfigurationDirectory=netbird
|
||||||
LogDirectory=netbird
|
LogsDirectory=netbird
|
||||||
RuntimeDirectory=netbird
|
RuntimeDirectory=netbird
|
||||||
StateDirectory=netbird
|
StateDirectory=netbird
|
||||||
|
|
||||||
@@ -28,7 +28,8 @@ ProtectControlGroups=yes
|
|||||||
ProtectHome=yes
|
ProtectHome=yes
|
||||||
ProtectHostname=yes
|
ProtectHostname=yes
|
||||||
ProtectKernelLogs=yes
|
ProtectKernelLogs=yes
|
||||||
ProtectKernelModules=no # needed to load wg module for kernel-mode WireGuard
|
# needed to load wg module for kernel-mode WireGuard
|
||||||
|
ProtectKernelModules=no
|
||||||
ProtectKernelTunables=no
|
ProtectKernelTunables=no
|
||||||
ProtectSystem=yes
|
ProtectSystem=yes
|
||||||
RemoveIPC=yes
|
RemoveIPC=yes
|
||||||
|
|||||||
22
route/hauniqueid.go
Normal file
22
route/hauniqueid.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package route
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
type HAUniqueID string
|
||||||
|
|
||||||
|
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
|
||||||
|
func GetHAUniqueID(input *Route) HAUniqueID {
|
||||||
|
return HAUniqueID(string(input.NetID) + "-" + input.Network.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (id HAUniqueID) String() string {
|
||||||
|
return string(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetID returns the Network ID from the HAUniqueID
|
||||||
|
func (id HAUniqueID) NetID() NetID {
|
||||||
|
if i := strings.LastIndex(string(id), "-"); i != -1 {
|
||||||
|
return NetID(id[:i])
|
||||||
|
}
|
||||||
|
return NetID(id)
|
||||||
|
}
|
||||||
@@ -36,6 +36,12 @@ const (
|
|||||||
IPv6Network
|
IPv6Network
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ID string
|
||||||
|
|
||||||
|
type NetID string
|
||||||
|
|
||||||
|
type HAMap map[HAUniqueID][]*Route
|
||||||
|
|
||||||
// NetworkType route network type
|
// NetworkType route network type
|
||||||
type NetworkType int
|
type NetworkType int
|
||||||
|
|
||||||
@@ -65,11 +71,11 @@ func ToPrefixType(prefix string) NetworkType {
|
|||||||
|
|
||||||
// Route represents a route
|
// Route represents a route
|
||||||
type Route struct {
|
type Route struct {
|
||||||
ID string `gorm:"primaryKey"`
|
ID ID `gorm:"primaryKey"`
|
||||||
// AccountID is a reference to Account that this object belongs
|
// AccountID is a reference to Account that this object belongs
|
||||||
AccountID string `gorm:"index"`
|
AccountID string `gorm:"index"`
|
||||||
Network netip.Prefix `gorm:"serializer:json"`
|
Network netip.Prefix `gorm:"serializer:json"`
|
||||||
NetID string
|
NetID NetID
|
||||||
Description string
|
Description string
|
||||||
Peer string
|
Peer string
|
||||||
PeerGroups []string `gorm:"serializer:json"`
|
PeerGroups []string `gorm:"serializer:json"`
|
||||||
@@ -165,8 +171,3 @@ func compareList(list, other []string) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
|
|
||||||
func GetHAUniqueID(input *Route) string {
|
|
||||||
return input.NetID + "-" + input.Network.String()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -33,7 +34,7 @@ type Client interface {
|
|||||||
io.Closer
|
io.Closer
|
||||||
StreamConnected() bool
|
StreamConnected() bool
|
||||||
GetStatus() Status
|
GetStatus() Status
|
||||||
Receive(msgHandler func(msg *proto.Message) error) error
|
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
||||||
Ready() bool
|
Ready() bool
|
||||||
IsHealthy() bool
|
IsHealthy() bool
|
||||||
WaitStreamConnected()
|
WaitStreamConnected()
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
keyA, _ := wgtypes.GenerateKey()
|
keyA, _ := wgtypes.GenerateKey()
|
||||||
clientA := createSignalClient(addr, keyA)
|
clientA := createSignalClient(addr, keyA)
|
||||||
go func() {
|
go func() {
|
||||||
err := clientA.Receive(func(msg *sigProto.Message) error {
|
err := clientA.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||||
payloadReceivedOnA = msg.GetBody().GetPayload()
|
payloadReceivedOnA = msg.GetBody().GetPayload()
|
||||||
featuresSupportedReceivedOnA = msg.GetBody().GetFeaturesSupported()
|
featuresSupportedReceivedOnA = msg.GetBody().GetFeaturesSupported()
|
||||||
msgReceived.Done()
|
msgReceived.Done()
|
||||||
@@ -72,7 +72,7 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
clientB := createSignalClient(addr, keyB)
|
clientB := createSignalClient(addr, keyB)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := clientB.Receive(func(msg *sigProto.Message) error {
|
err := clientB.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||||
payloadReceivedOnB = msg.GetBody().GetPayload()
|
payloadReceivedOnB = msg.GetBody().GetPayload()
|
||||||
featuresSupportedReceivedOnB = msg.GetBody().GetFeaturesSupported()
|
featuresSupportedReceivedOnB = msg.GetBody().GetFeaturesSupported()
|
||||||
err := clientB.Send(&sigProto.Message{
|
err := clientB.Send(&sigProto.Message{
|
||||||
@@ -122,7 +122,7 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
key, _ := wgtypes.GenerateKey()
|
key, _ := wgtypes.GenerateKey()
|
||||||
client := createSignalClient(addr, key)
|
client := createSignalClient(addr, key)
|
||||||
go func() {
|
go func() {
|
||||||
err := client.Receive(func(msg *sigProto.Message) error {
|
err := client.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -126,9 +126,9 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
|
|||||||
// The messages will be handled by msgHandler function provided.
|
// The messages will be handled by msgHandler function provided.
|
||||||
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
||||||
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
||||||
func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error {
|
||||||
|
|
||||||
var backOff = defaultBackoff(c.ctx)
|
var backOff = defaultBackoff(ctx)
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
|
|
||||||
@@ -139,13 +139,13 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
|||||||
if connState == connectivity.Shutdown {
|
if connState == connectivity.Shutdown {
|
||||||
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
|
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
|
||||||
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
||||||
c.signalConn.WaitForStateChange(c.ctx, connState)
|
c.signalConn.WaitForStateChange(ctx, connState)
|
||||||
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
|
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to Signal stream identifying ourselves with a public WireGuard key
|
// connect to Signal stream identifying ourselves with a public WireGuard key
|
||||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
ctx, cancelStream := context.WithCancel(ctx)
|
||||||
defer cancelStream()
|
defer cancelStream()
|
||||||
stream, err := c.connect(ctx, c.key.PublicKey().String())
|
stream, err := c.connect(ctx, c.key.PublicKey().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user