mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-20 06:49:55 +00:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0c3818e06 | ||
|
|
6922826919 | ||
|
|
56a1a75e3f | ||
|
|
d9402168ad | ||
|
|
dbdef04b9e | ||
|
|
29cbfe8467 | ||
|
|
6ce8643368 | ||
|
|
07d1ad35fc | ||
|
|
ef6cd36f1a | ||
|
|
c1c71b6d39 | ||
|
|
0480507a10 | ||
|
|
34ac4e4b5a | ||
|
|
52ff9d9602 | ||
|
|
1b73fae46e | ||
|
|
d897365abc | ||
|
|
f37aa2cc9d | ||
|
|
5343bee7b2 | ||
|
|
870e29db63 | ||
|
|
08e9b05d51 | ||
|
|
3581648071 | ||
|
|
2a51609436 | ||
|
|
83457f8b99 | ||
|
|
b45284f086 | ||
|
|
e9016aecea | ||
|
|
23b5d45b68 | ||
|
|
0e5dc9d412 | ||
|
|
91f7ee6a3c | ||
|
|
7c6b85b4cb |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
SIGN_PIPE_VER: "v0.0.20"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -134,6 +134,7 @@ jobs:
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
|
||||
run: |
|
||||
set -x
|
||||
@@ -180,6 +181,7 @@ jobs:
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
FROM alpine:3.21.3
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
FROM alpine:3.21.0
|
||||
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||
|
||||
RUN apk add --no-cache ca-certificates \
|
||||
&& adduser -D -h /var/lib/netbird netbird
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// Preferences export a subset of the internal config for gomobile
|
||||
// Preferences exports a subset of the internal config for gomobile
|
||||
type Preferences struct {
|
||||
configInput internal.ConfigInput
|
||||
}
|
||||
|
||||
// NewPreferences create new Preferences instance
|
||||
// NewPreferences creates a new Preferences instance
|
||||
func NewPreferences(configPath string) *Preferences {
|
||||
ci := internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
||||
return &Preferences{ci}
|
||||
}
|
||||
|
||||
// GetManagementURL read url from config file
|
||||
// GetManagementURL reads URL from config file
|
||||
func (p *Preferences) GetManagementURL() (string, error) {
|
||||
if p.configInput.ManagementURL != "" {
|
||||
return p.configInput.ManagementURL, nil
|
||||
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
||||
return cfg.ManagementURL.String(), err
|
||||
}
|
||||
|
||||
// SetManagementURL store the given url and wait for commit
|
||||
// SetManagementURL stores the given URL and waits for commit
|
||||
func (p *Preferences) SetManagementURL(url string) {
|
||||
p.configInput.ManagementURL = url
|
||||
}
|
||||
|
||||
// GetAdminURL read url from config file
|
||||
// GetAdminURL reads URL from config file
|
||||
func (p *Preferences) GetAdminURL() (string, error) {
|
||||
if p.configInput.AdminURL != "" {
|
||||
return p.configInput.AdminURL, nil
|
||||
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
||||
return cfg.AdminURL.String(), err
|
||||
}
|
||||
|
||||
// SetAdminURL store the given url and wait for commit
|
||||
// SetAdminURL stores the given URL and waits for commit
|
||||
func (p *Preferences) SetAdminURL(url string) {
|
||||
p.configInput.AdminURL = url
|
||||
}
|
||||
|
||||
// GetPreSharedKey read preshared key from config file
|
||||
// GetPreSharedKey reads pre-shared key from config file
|
||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
if p.configInput.PreSharedKey != nil {
|
||||
return *p.configInput.PreSharedKey, nil
|
||||
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
return cfg.PreSharedKey, err
|
||||
}
|
||||
|
||||
// SetPreSharedKey store the given key and wait for commit
|
||||
// SetPreSharedKey stores the given key and waits for commit
|
||||
func (p *Preferences) SetPreSharedKey(key string) {
|
||||
p.configInput.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// SetRosenpassEnabled store if rosenpass is enabled
|
||||
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||
p.configInput.RosenpassEnabled = &enabled
|
||||
}
|
||||
|
||||
// GetRosenpassEnabled read rosenpass enabled from config file
|
||||
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
if p.configInput.RosenpassEnabled != nil {
|
||||
return *p.configInput.RosenpassEnabled, nil
|
||||
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
return cfg.RosenpassEnabled, err
|
||||
}
|
||||
|
||||
// SetRosenpassPermissive store the given permissive and wait for commit
|
||||
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||
p.configInput.RosenpassPermissive = &permissive
|
||||
}
|
||||
|
||||
// GetRosenpassPermissive read rosenpass permissive from config file
|
||||
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
if p.configInput.RosenpassPermissive != nil {
|
||||
return *p.configInput.RosenpassPermissive, nil
|
||||
@@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
return cfg.RosenpassPermissive, err
|
||||
}
|
||||
|
||||
// Commit write out the changes into config file
|
||||
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||
if p.configInput.DisableClientRoutes != nil {
|
||||
return *p.configInput.DisableClientRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableClientRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableClientRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||
p.configInput.DisableClientRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||
if p.configInput.DisableServerRoutes != nil {
|
||||
return *p.configInput.DisableServerRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableServerRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableServerRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||
p.configInput.DisableServerRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableDNS reads disable DNS setting from config file
|
||||
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||
if p.configInput.DisableDNS != nil {
|
||||
return *p.configInput.DisableDNS, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableDNS, err
|
||||
}
|
||||
|
||||
// SetDisableDNS stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||
p.configInput.DisableDNS = &disable
|
||||
}
|
||||
|
||||
// GetDisableFirewall reads disable firewall setting from config file
|
||||
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||
if p.configInput.DisableFirewall != nil {
|
||||
return *p.configInput.DisableFirewall, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableFirewall, err
|
||||
}
|
||||
|
||||
// SetDisableFirewall stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||
p.configInput.DisableFirewall = &disable
|
||||
}
|
||||
|
||||
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||
if p.configInput.ServerSSHAllowed != nil {
|
||||
return *p.configInput.ServerSSHAllowed, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.ServerSSHAllowed == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.ServerSSHAllowed, err
|
||||
}
|
||||
|
||||
// SetServerSSHAllowed stores the given value and waits for commit
|
||||
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
return *p.configInput.BlockInbound, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.BlockInbound, err
|
||||
}
|
||||
|
||||
// SetBlockInbound stores the given value and waits for commit
|
||||
func (p *Preferences) SetBlockInbound(block bool) {
|
||||
p.configInput.BlockInbound = &block
|
||||
}
|
||||
|
||||
// Commit writes out the changes to the config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||
return err
|
||||
|
||||
210
client/cmd/flags.go
Normal file
210
client/cmd/flags.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// SharedFlags contains all configuration flags that are common between up and set commands
|
||||
type SharedFlags struct {
|
||||
// Network configuration
|
||||
InterfaceName string
|
||||
WireguardPort uint16
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress string
|
||||
ExtraIFaceBlackList []string
|
||||
DNSLabels []string
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
// Feature flags
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
AutoConnectDisabled bool
|
||||
NetworkMonitor bool
|
||||
LazyConnEnabled bool
|
||||
|
||||
// System flags
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
// Login-specific (only for up command)
|
||||
NoBrowser bool
|
||||
}
|
||||
|
||||
// AddSharedFlags adds all shared configuration flags to the given command
|
||||
func AddSharedFlags(cmd *cobra.Command, flags *SharedFlags) {
|
||||
// Network configuration flags
|
||||
cmd.PersistentFlags().StringVar(&flags.InterfaceName, interfaceNameFlag, iface.WgInterfaceDefault,
|
||||
"Wireguard interface name")
|
||||
cmd.PersistentFlags().Uint16Var(&flags.WireguardPort, wireguardPortFlag, iface.DefaultWgPort,
|
||||
"Wireguard interface listening port")
|
||||
cmd.PersistentFlags().StringSliceVar(&flags.NATExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces. `+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
|
||||
`or --external-ip-map ""`)
|
||||
cmd.PersistentFlags().StringVar(&flags.CustomDNSAddress, dnsResolverAddress, "",
|
||||
`Sets a custom address for NetBird's local DNS resolver. `+
|
||||
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`)
|
||||
cmd.PersistentFlags().StringSliceVar(&flags.ExtraIFaceBlackList, extraIFaceBlackListFlag, nil,
|
||||
"Extra list of default interfaces to ignore for listening")
|
||||
cmd.PersistentFlags().StringSliceVar(&flags.DNSLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels. `+
|
||||
`You can specify a comma-separated list of up to 32 labels. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`)
|
||||
cmd.PersistentFlags().DurationVar(&flags.DNSRouteInterval, dnsRouteIntervalFlag, time.Minute,
|
||||
"DNS route update interval")
|
||||
|
||||
// Feature flags
|
||||
cmd.PersistentFlags().BoolVar(&flags.RosenpassEnabled, enableRosenpassFlag, false,
|
||||
"[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.RosenpassPermissive, rosenpassPermissiveFlag, false,
|
||||
"[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.ServerSSHAllowed, serverSSHAllowedFlag, false,
|
||||
"Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
cmd.PersistentFlags().BoolVar(&flags.AutoConnectDisabled, disableAutoConnectFlag, false,
|
||||
"Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
cmd.PersistentFlags().BoolVarP(&flags.NetworkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`)
|
||||
cmd.PersistentFlags().BoolVar(&flags.LazyConnEnabled, enableLazyConnectionFlag, false,
|
||||
"[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
// System flags (from system.go)
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableClientRoutes, disableClientRoutesFlag, false,
|
||||
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableServerRoutes, disableServerRoutesFlag, false,
|
||||
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableDNS, disableDNSFlag, false,
|
||||
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.BlockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
cmd.PersistentFlags().BoolVar(&flags.BlockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
// AddUpOnlyFlags adds flags that are specific to the up command
|
||||
func AddUpOnlyFlags(cmd *cobra.Command, flags *SharedFlags) {
|
||||
cmd.PersistentFlags().BoolVar(&flags.NoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
}
|
||||
|
||||
// BuildConfigInput creates an internal.ConfigInput from SharedFlags with Changed() checks
|
||||
func BuildConfigInput(cmd *cobra.Command, flags *SharedFlags, customDNSAddressConverted []byte) (*internal.ConfigInput, error) {
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
}
|
||||
|
||||
// Handle PreSharedKey from root command
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &flags.RosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &flags.RosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &flags.ServerSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(flags.InterfaceName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ic.InterfaceName = &flags.InterfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(flags.WireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &flags.NetworkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &flags.AutoConnectDisabled
|
||||
|
||||
if flags.AutoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
} else {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &flags.DNSRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &flags.DisableClientRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &flags.DisableServerRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &flags.DisableDNS
|
||||
}
|
||||
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &flags.DisableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &flags.BlockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &flags.BlockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &flags.LazyConnEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(externalIPMapFlag).Changed {
|
||||
ic.NATExternalIPs = flags.NATExternalIPs
|
||||
}
|
||||
|
||||
if cmd.Flag(extraIFaceBlackListFlag).Changed {
|
||||
ic.ExtraIFaceBlackList = flags.ExtraIFaceBlackList
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
var err error
|
||||
ic.DNSLabels, err = domain.FromStringList(flags.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid DNS labels: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &ic, nil
|
||||
}
|
||||
@@ -149,6 +149,7 @@ func init() {
|
||||
rootCmd.AddCommand(loginCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(setCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
@@ -167,24 +168,6 @@ func init() {
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
|
||||
`or --external-ip-map ""`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringVar(&customDNSAddress, dnsResolverAddress, "",
|
||||
`Sets a custom address for NetBird's local DNS resolver. `+
|
||||
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
|
||||
161
client/cmd/set.go
Normal file
161
client/cmd/set.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
setFlags = &SharedFlags{}
|
||||
|
||||
setCmd = &cobra.Command{
|
||||
Use: "set",
|
||||
Short: "Update NetBird client configuration",
|
||||
Long: `Update NetBird client configuration without connecting. Uses the same flags as 'netbird up' but only updates the configuration file.`,
|
||||
RunE: setFunc,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add all shared flags to the set command
|
||||
AddSharedFlags(setCmd, setFlags)
|
||||
// Note: We don't add up-only flags like --no-browser to set command
|
||||
}
|
||||
|
||||
func setFunc(cmd *cobra.Command, _ []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
// Validate inputs (reuse validation logic from up.go)
|
||||
if err := validateNATExternalIPs(setFlags.NATExternalIPs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
if _, err := validateDnsLabels(setFlags.DNSLabels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var customDNSAddressConverted []byte
|
||||
if cmd.Flag(dnsResolverAddress).Changed {
|
||||
var err error
|
||||
customDNSAddressConverted, err = parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to daemon
|
||||
ctx := cmd.Context()
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
fmt.Printf("Warning: failed to close connection: %v\n", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SetConfigRequest{}
|
||||
|
||||
// Set fields based on changed flags
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
req.RosenpassEnabled = &setFlags.RosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
req.RosenpassPermissive = &setFlags.RosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &setFlags.ServerSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
req.DisableAutoConnect = &setFlags.AutoConnectDisabled
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
req.NetworkMonitor = &setFlags.NetworkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(setFlags.InterfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
req.InterfaceName = &setFlags.InterfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
port := int64(setFlags.WireguardPort)
|
||||
req.WireguardPort = &port
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsResolverAddress).Changed {
|
||||
customAddr := string(customDNSAddressConverted)
|
||||
req.CustomDNSAddress = &customAddr
|
||||
}
|
||||
|
||||
if cmd.Flag(extraIFaceBlackListFlag).Changed {
|
||||
req.ExtraIFaceBlacklist = setFlags.ExtraIFaceBlackList
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
req.DnsLabels = setFlags.DNSLabels
|
||||
req.CleanDNSLabels = &[]bool{setFlags.DNSLabels != nil && len(setFlags.DNSLabels) == 0}[0]
|
||||
}
|
||||
|
||||
if cmd.Flag(externalIPMapFlag).Changed {
|
||||
req.NatExternalIPs = setFlags.NATExternalIPs
|
||||
req.CleanNATExternalIPs = &[]bool{setFlags.NATExternalIPs != nil && len(setFlags.NATExternalIPs) == 0}[0]
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
req.DnsRouteInterval = durationpb.New(setFlags.DNSRouteInterval)
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
req.DisableClientRoutes = &setFlags.DisableClientRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
req.DisableServerRoutes = &setFlags.DisableServerRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
req.DisableDns = &setFlags.DisableDNS
|
||||
}
|
||||
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
req.DisableFirewall = &setFlags.DisableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
req.BlockLanAccess = &setFlags.BlockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
req.BlockInbound = &setFlags.BlockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
req.LazyConnectionEnabled = &setFlags.LazyConnEnabled
|
||||
}
|
||||
|
||||
// Send the request
|
||||
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||
return fmt.Errorf("update configuration: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Configuration updated successfully")
|
||||
return nil
|
||||
}
|
||||
110
client/cmd/set_test.go
Normal file
110
client/cmd/set_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestParseBoolArg(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected bool
|
||||
hasError bool
|
||||
}{
|
||||
{"true", true, false},
|
||||
{"True", true, false},
|
||||
{"1", true, false},
|
||||
{"yes", true, false},
|
||||
{"on", true, false},
|
||||
{"false", false, false},
|
||||
{"False", false, false},
|
||||
{"0", false, false},
|
||||
{"no", false, false},
|
||||
{"off", false, false},
|
||||
{"invalid", false, true},
|
||||
{"maybe", false, true},
|
||||
{"", false, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
result, err := parseBoolArg(test.input)
|
||||
|
||||
if test.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for input %q, but got none", test.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for input %q: %v", test.input, err)
|
||||
}
|
||||
if result != test.expected {
|
||||
t.Errorf("For input %q, expected %v but got %v", test.input, test.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommandStructure(t *testing.T) {
|
||||
// Test that the set command has the expected subcommands
|
||||
expectedSubcommands := []string{
|
||||
"autoconnect",
|
||||
"ssh-server",
|
||||
"network-monitor",
|
||||
"rosenpass",
|
||||
"dns",
|
||||
"dns-interval",
|
||||
}
|
||||
|
||||
actualSubcommands := make([]string, 0, len(setCmd.Commands()))
|
||||
for _, cmd := range setCmd.Commands() {
|
||||
actualSubcommands = append(actualSubcommands, cmd.Name())
|
||||
}
|
||||
|
||||
if len(actualSubcommands) != len(expectedSubcommands) {
|
||||
t.Errorf("Expected %d subcommands, got %d", len(expectedSubcommands), len(actualSubcommands))
|
||||
}
|
||||
|
||||
for _, expected := range expectedSubcommands {
|
||||
found := false
|
||||
for _, actual := range actualSubcommands {
|
||||
if actual == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected subcommand %q not found", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommandUsage(t *testing.T) {
|
||||
if setCmd.Use != "set" {
|
||||
t.Errorf("Expected command use to be 'set', got %q", setCmd.Use)
|
||||
}
|
||||
|
||||
if setCmd.Short != "Set NetBird client configuration" {
|
||||
t.Errorf("Expected short description to be 'Set NetBird client configuration', got %q", setCmd.Short)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubcommandArgRequirements(t *testing.T) {
|
||||
// Test that all subcommands except dns-interval require exactly 1 argument
|
||||
subcommands := []*cobra.Command{
|
||||
setAutoconnectCmd,
|
||||
setSSHServerCmd,
|
||||
setNetworkMonitorCmd,
|
||||
setRosenpassCmd,
|
||||
setDNSCmd,
|
||||
setDNSIntervalCmd,
|
||||
}
|
||||
|
||||
for _, cmd := range subcommands {
|
||||
if cmd.Args == nil {
|
||||
t.Errorf("Command %q should have Args validation", cmd.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
@@ -19,24 +19,3 @@ var (
|
||||
blockInbound bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add system flags to upCmd
|
||||
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
|
||||
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
|
||||
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
|
||||
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
124
client/cmd/up.go
124
client/cmd/up.go
@@ -7,7 +7,6 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -42,6 +40,7 @@ var (
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
upFlags = &SharedFlags{}
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
@@ -51,26 +50,12 @@ var (
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add shared flags to up command
|
||||
AddSharedFlags(upCmd, upFlags)
|
||||
|
||||
// Add up-specific flags
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
`You can specify a comma-separated list of up to 32 labels. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
AddUpOnlyFlags(upCmd, upFlags)
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -118,7 +103,16 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
||||
// Handle DNS labels validation and assignment to SharedFlags
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
var err error
|
||||
dnsLabelsValidated, err = validateDnsLabels(upFlags.DNSLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ic, err := BuildConfigInput(cmd, upFlags, customDNSAddressConverted)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup config: %v", err)
|
||||
}
|
||||
@@ -235,92 +229,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
|
||||
@@ -24,6 +24,7 @@ type WGTunDevice struct {
|
||||
mtu int
|
||||
iceBind *bind.ICEBind
|
||||
tunAdapter TunAdapter
|
||||
disableDNS bool
|
||||
|
||||
name string
|
||||
device *device.Device
|
||||
@@ -32,7 +33,7 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
routesString := routesToString(routes)
|
||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||
|
||||
// Skip DNS configuration when DisableDNS is enabled
|
||||
if t.disableDNS {
|
||||
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
||||
dns = ""
|
||||
searchDomainsToString = ""
|
||||
}
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
|
||||
@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn bind.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
// WGIface represents an interface instance
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
@@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||
if drop {
|
||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||
|
||||
if hasPortRestrictions {
|
||||
// Don't squash rules with port restrictions
|
||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = &protoMatch{
|
||||
ips: map[string]int{},
|
||||
|
||||
@@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []*mgmProto.FirewallRule
|
||||
expectedCount int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "should not squash rules with port ranges",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with specific ports",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with legacy port field",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with legacy port field should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with DROP action",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with DROP action should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should squash rules without port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 1,
|
||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||
},
|
||||
{
|
||||
name: "mixed rules should not squash protocol with port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "TCP should not be squashed because one rule has port restrictions",
|
||||
},
|
||||
{
|
||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
// TCP rules with port restrictions - should NOT be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
// UDP rules without port restrictions - SHOULD be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: tt.rules,
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
|
||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||
|
||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||
if tt.expectedCount == 1 {
|
||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
portInfo *mgmProto.PortInfo
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil PortInfo should be empty",
|
||||
portInfo: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero port should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 0,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with valid port should not be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with nil range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: nil,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero start range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 0,
|
||||
End: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero end range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 80,
|
||||
End: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with valid range should not be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := portInfoEmpty(tt.portInfo)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
|
||||
@@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
config := &Config{
|
||||
// defaults to false only for new (post 0.26) configurations
|
||||
ServerSSHAllowed: util.False(),
|
||||
// default to disabling server routes on Android for security
|
||||
DisableServerRoutes: runtime.GOOS == "android",
|
||||
}
|
||||
|
||||
if _, err := config.apply(input); err != nil {
|
||||
@@ -317,10 +319,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
*input.WireguardPort, config.WgPort)
|
||||
config.WgPort = *input.WireguardPort
|
||||
updated = true
|
||||
} else if config.WgPort == 0 {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
log.Infof("using default Wireguard port %d", config.WgPort)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||
@@ -416,9 +414,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
updated = true
|
||||
} else if config.ServerSSHAllowed == nil {
|
||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||
config.ServerSSHAllowed = util.True()
|
||||
if runtime.GOOS == "android" {
|
||||
// default to disabled SSH on Android for security
|
||||
log.Infof("setting SSH server to false by default on Android")
|
||||
config.ServerSSHAllowed = util.False()
|
||||
} else {
|
||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||
config.ServerSSHAllowed = util.True()
|
||||
}
|
||||
updated = true
|
||||
}
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
|
||||
PeerConnID: conn.ConnID(),
|
||||
Log: conn.Log,
|
||||
}
|
||||
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -526,17 +525,13 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
||||
|
||||
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
||||
func freePort(initPort int) (int, error) {
|
||||
addr := net.UDPAddr{}
|
||||
if initPort == 0 {
|
||||
initPort = iface.DefaultWgPort
|
||||
}
|
||||
|
||||
addr.Port = initPort
|
||||
addr := net.UDPAddr{Port: initPort}
|
||||
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err == nil {
|
||||
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
|
||||
closeConnWithLog(conn)
|
||||
return initPort, nil
|
||||
return returnPort, nil
|
||||
}
|
||||
|
||||
// if the port is already in use, ask the system for a free port
|
||||
|
||||
@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "not provided, fallback to default",
|
||||
name: "when port is 0 use random port",
|
||||
port: 0,
|
||||
want: 51820,
|
||||
shouldMatch: true,
|
||||
want: 0,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "provided and available",
|
||||
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
@@ -39,6 +39,14 @@ func Test_freePort(t *testing.T) {
|
||||
_ = c1.Close()
|
||||
}(c1)
|
||||
|
||||
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
|
||||
tests[1].port++
|
||||
tests[1].want++
|
||||
}
|
||||
|
||||
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
DisableDNS: e.config.DisableDNS,
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
|
||||
@@ -1476,7 +1476,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -68,3 +68,8 @@ func (i *Monitor) PauseTimer() {
|
||||
func (i *Monitor) ResetTimer() {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
}
|
||||
|
||||
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
|
||||
i.Stop()
|
||||
go i.Start(ctx, timeoutChan)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ type Manager struct {
|
||||
// Route HA group management
|
||||
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
|
||||
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
|
||||
routesMu sync.RWMutex // protects route mappings
|
||||
routesMu sync.RWMutex
|
||||
|
||||
onInactive chan peerid.ConnID
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func (m *Manager) Start(ctx context.Context) {
|
||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(ctx, peerConnID)
|
||||
case peerConnID := <-m.onInactive:
|
||||
m.onPeerInactivityTimedOut(peerConnID)
|
||||
m.onPeerInactivityTimedOut(ctx, peerConnID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,7 +197,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
|
||||
return added
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -225,6 +225,13 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherActivity,
|
||||
}
|
||||
|
||||
// Check if this peer should be activated because its HA group peers are active
|
||||
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
|
||||
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
|
||||
m.activateNewPeerInActiveGroup(ctx, peerCfg)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -315,36 +322,38 @@ func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConf
|
||||
|
||||
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
|
||||
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
|
||||
var peersToActivate []string
|
||||
|
||||
m.routesMu.RLock()
|
||||
haGroups := m.peerToHAGroups[triggerPeerID]
|
||||
m.routesMu.RUnlock()
|
||||
|
||||
if len(haGroups) == 0 {
|
||||
m.routesMu.RUnlock()
|
||||
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
|
||||
return
|
||||
}
|
||||
|
||||
activatedCount := 0
|
||||
for _, haGroup := range haGroups {
|
||||
m.routesMu.RLock()
|
||||
peers := m.haGroupToPeers[haGroup]
|
||||
m.routesMu.RUnlock()
|
||||
|
||||
for _, peerID := range peers {
|
||||
if peerID == triggerPeerID {
|
||||
continue
|
||||
if peerID != triggerPeerID {
|
||||
peersToActivate = append(peersToActivate, peerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.routesMu.RUnlock()
|
||||
|
||||
cfg, mp := m.getPeerForActivation(peerID)
|
||||
if cfg == nil {
|
||||
continue
|
||||
}
|
||||
activatedCount := 0
|
||||
for _, peerID := range peersToActivate {
|
||||
cfg, mp := m.getPeerForActivation(peerID)
|
||||
if cfg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if m.activateSinglePeer(ctx, cfg, mp) {
|
||||
activatedCount++
|
||||
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID)
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
|
||||
}
|
||||
if m.activateSinglePeer(ctx, cfg, mp) {
|
||||
activatedCount++
|
||||
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,6 +363,51 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string
|
||||
}
|
||||
}
|
||||
|
||||
// shouldActivateNewPeer checks if a newly added peer should be activated
|
||||
// because other peers in its HA groups are already active
|
||||
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
|
||||
m.routesMu.RLock()
|
||||
defer m.routesMu.RUnlock()
|
||||
|
||||
haGroups := m.peerToHAGroups[peerID]
|
||||
if len(haGroups) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, haGroup := range haGroups {
|
||||
peers := m.haGroupToPeers[haGroup]
|
||||
for _, groupPeerID := range peers {
|
||||
if groupPeerID == peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, ok := m.managedPeers[groupPeerID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
|
||||
return haGroup, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
|
||||
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
|
||||
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
|
||||
return
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
|
||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||
peerCfg.Log.Warnf("peer already managed")
|
||||
@@ -415,6 +469,48 @@ func (m *Manager) close() {
|
||||
log.Infof("lazy connection manager closed")
|
||||
}
|
||||
|
||||
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
|
||||
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
|
||||
m.routesMu.RLock()
|
||||
defer m.routesMu.RUnlock()
|
||||
|
||||
haGroups := m.peerToHAGroups[peerID]
|
||||
if len(haGroups) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, haGroup := range haGroups {
|
||||
groupPeers := m.haGroupToPeers[haGroup]
|
||||
|
||||
for _, groupPeerID := range groupPeers {
|
||||
if groupPeerID == peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, ok := m.managedPeers[groupPeerID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if groupMp.expectedWatcher != watcherInactivity {
|
||||
continue
|
||||
}
|
||||
|
||||
// Other member is still connected, defer idle
|
||||
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
@@ -441,7 +537,7 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
|
||||
func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -456,6 +552,17 @@ func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
|
||||
iw, ok := m.inactivityMonitors[peerConnID]
|
||||
if ok {
|
||||
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
|
||||
iw.ResetMonitor(ctx, m.onInactive)
|
||||
} else {
|
||||
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("connection timed out")
|
||||
|
||||
// this is blocking operation, potentially can be optimized
|
||||
@@ -489,7 +596,7 @@ func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -317,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig {
|
||||
return conn.config.WgConfig
|
||||
}
|
||||
|
||||
// IsConnected unit tests only
|
||||
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
|
||||
// IsConnected returns true if the peer is connected
|
||||
func (conn *Conn) IsConnected() bool {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
return conn.currentConnPriority != conntype.None
|
||||
|
||||
return conn.evalStatus() == StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) GetKey() string {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
|
||||
@@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets := make([]string, 0)
|
||||
for _, r := range clientRoutes {
|
||||
// filter out domain routes
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
@@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
if runtime.GOOS != "android" {
|
||||
return
|
||||
}
|
||||
newNets := make([]string, 0)
|
||||
|
||||
var newNets []string
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
newNets = append(newNets, r.Network.String())
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
|
||||
// OnNewPrefixes is called from iOS only
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
for _, prefix := range prefixes {
|
||||
@@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,10 @@ func (n Nexthop) String() string {
|
||||
if n.Intf == nil {
|
||||
return n.IP.String()
|
||||
}
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
if n.IP.IsValid() {
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
return fmt.Sprintf("no-ip @ %d (%s)", n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
|
||||
type wgIface interface {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
<Wix
|
||||
xmlns="http://wixtoolset.org/schemas/v4/wxs">
|
||||
xmlns="http://wixtoolset.org/schemas/v4/wxs"
|
||||
xmlns:util="http://wixtoolset.org/schemas/v4/wxs/util">
|
||||
<Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="NetBird GmbH" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf"
|
||||
InstallerVersion="500" Compressed="yes" Codepage="utf-8" >
|
||||
|
||||
|
||||
<MediaTemplate EmbedCab="yes" />
|
||||
|
||||
<Feature Id="NetbirdFeature" Title="Netbird" Level="1">
|
||||
@@ -46,29 +48,10 @@
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
</ComponentGroup>
|
||||
|
||||
<Property Id="cmd" Value="cmd.exe"/>
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />
|
||||
|
||||
<CustomAction Id="KillDaemon"
|
||||
ExeCommand='/c "taskkill /im netbird.exe"'
|
||||
Execute="deferred"
|
||||
Property="cmd"
|
||||
Impersonate="no"
|
||||
Return="ignore"
|
||||
/>
|
||||
|
||||
<CustomAction Id="KillUI"
|
||||
ExeCommand='/c "taskkill /im netbird-ui.exe"'
|
||||
Execute="deferred"
|
||||
Property="cmd"
|
||||
Impersonate="no"
|
||||
Return="ignore"
|
||||
/>
|
||||
|
||||
<InstallExecuteSequence>
|
||||
<!-- For Uninstallation -->
|
||||
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
|
||||
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/>
|
||||
</InstallExecuteSequence>
|
||||
|
||||
<!-- Icons -->
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -67,6 +67,9 @@ service DaemonService {
|
||||
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
|
||||
|
||||
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
|
||||
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
rpc SetConfig(SetConfigRequest) returns (SetConfigResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -158,6 +161,7 @@ message UpResponse {}
|
||||
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
}
|
||||
|
||||
message StatusResponse{
|
||||
@@ -495,3 +499,29 @@ message GetEventsRequest {}
|
||||
message GetEventsResponse {
|
||||
repeated SystemEvent events = 1;
|
||||
}
|
||||
|
||||
message SetConfigRequest {
|
||||
optional bool rosenpassEnabled = 1;
|
||||
optional bool rosenpassPermissive = 2;
|
||||
optional bool serverSSHAllowed = 3;
|
||||
optional bool disableAutoConnect = 4;
|
||||
optional bool networkMonitor = 5;
|
||||
optional google.protobuf.Duration dnsRouteInterval = 6;
|
||||
optional bool disable_client_routes = 7;
|
||||
optional bool disable_server_routes = 8;
|
||||
optional bool disable_dns = 9;
|
||||
optional bool disable_firewall = 10;
|
||||
optional bool block_lan_access = 11;
|
||||
optional bool lazyConnectionEnabled = 12;
|
||||
optional bool block_inbound = 13;
|
||||
optional string interfaceName = 14;
|
||||
optional int64 wireguardPort = 15;
|
||||
optional string customDNSAddress = 16;
|
||||
repeated string extraIFaceBlacklist = 17;
|
||||
repeated string dns_labels = 18;
|
||||
optional bool cleanDNSLabels = 19;
|
||||
repeated string natExternalIPs = 20;
|
||||
optional bool cleanNATExternalIPs = 21;
|
||||
}
|
||||
|
||||
message SetConfigResponse {}
|
||||
|
||||
@@ -55,6 +55,8 @@ type DaemonServiceClient interface {
|
||||
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
|
||||
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
|
||||
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -268,6 +270,15 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) {
|
||||
out := new(SetConfigResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetConfig", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -309,6 +320,8 @@ type DaemonServiceServer interface {
|
||||
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
|
||||
SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
|
||||
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -376,6 +389,9 @@ func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, Daemo
|
||||
func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SetConfig not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -752,6 +768,24 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SetConfig_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SetConfigRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).SetConfig(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SetConfig",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SetConfig(ctx, req.(*SetConfigRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -835,6 +869,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetEvents",
|
||||
Handler: _DaemonService_GetEvents_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SetConfig",
|
||||
Handler: _DaemonService_SetConfig_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
@@ -707,7 +707,9 @@ func (s *Server) Status(
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
if msg.GetFullPeerStatus {
|
||||
s.runProbes()
|
||||
if msg.ShouldRunProbes {
|
||||
s.runProbes()
|
||||
}
|
||||
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
@@ -797,6 +799,133 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
func (s *Server) SetConfig(ctx context.Context, req *proto.SetConfigRequest) (*proto.SetConfigResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "daemon is not configured")
|
||||
}
|
||||
|
||||
configChanged := false
|
||||
|
||||
if req.RosenpassEnabled != nil && s.config.RosenpassEnabled != *req.RosenpassEnabled {
|
||||
s.config.RosenpassEnabled = *req.RosenpassEnabled
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.RosenpassPermissive != nil && s.config.RosenpassPermissive != *req.RosenpassPermissive {
|
||||
s.config.RosenpassPermissive = *req.RosenpassPermissive
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.ServerSSHAllowed != nil && s.config.ServerSSHAllowed != nil && *s.config.ServerSSHAllowed != *req.ServerSSHAllowed {
|
||||
*s.config.ServerSSHAllowed = *req.ServerSSHAllowed
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableAutoConnect != nil && s.config.DisableAutoConnect != *req.DisableAutoConnect {
|
||||
s.config.DisableAutoConnect = *req.DisableAutoConnect
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.NetworkMonitor != nil && s.config.NetworkMonitor != nil && *s.config.NetworkMonitor != *req.NetworkMonitor {
|
||||
*s.config.NetworkMonitor = *req.NetworkMonitor
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DnsRouteInterval != nil {
|
||||
duration := req.DnsRouteInterval.AsDuration()
|
||||
if s.config.DNSRouteInterval != duration {
|
||||
s.config.DNSRouteInterval = duration
|
||||
configChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if req.DisableClientRoutes != nil && s.config.DisableClientRoutes != *req.DisableClientRoutes {
|
||||
s.config.DisableClientRoutes = *req.DisableClientRoutes
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableServerRoutes != nil && s.config.DisableServerRoutes != *req.DisableServerRoutes {
|
||||
s.config.DisableServerRoutes = *req.DisableServerRoutes
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableDns != nil && s.config.DisableDNS != *req.DisableDns {
|
||||
s.config.DisableDNS = *req.DisableDns
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableFirewall != nil && s.config.DisableFirewall != *req.DisableFirewall {
|
||||
s.config.DisableFirewall = *req.DisableFirewall
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.BlockLanAccess != nil && s.config.BlockLANAccess != *req.BlockLanAccess {
|
||||
s.config.BlockLANAccess = *req.BlockLanAccess
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.LazyConnectionEnabled != nil && s.config.LazyConnectionEnabled != *req.LazyConnectionEnabled {
|
||||
s.config.LazyConnectionEnabled = *req.LazyConnectionEnabled
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.BlockInbound != nil && s.config.BlockInbound != *req.BlockInbound {
|
||||
s.config.BlockInbound = *req.BlockInbound
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.InterfaceName != nil && s.config.WgIface != *req.InterfaceName {
|
||||
s.config.WgIface = *req.InterfaceName
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.WireguardPort != nil && s.config.WgPort != int(*req.WireguardPort) {
|
||||
s.config.WgPort = int(*req.WireguardPort)
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.CustomDNSAddress != nil && s.config.CustomDNSAddress != *req.CustomDNSAddress {
|
||||
s.config.CustomDNSAddress = *req.CustomDNSAddress
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if len(req.ExtraIFaceBlacklist) > 0 {
|
||||
s.config.IFaceBlackList = req.ExtraIFaceBlacklist
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if len(req.DnsLabels) > 0 || (req.CleanDNSLabels != nil && *req.CleanDNSLabels) {
|
||||
if req.CleanDNSLabels != nil && *req.CleanDNSLabels {
|
||||
s.config.DNSLabels = domain.List{}
|
||||
} else {
|
||||
var err error
|
||||
s.config.DNSLabels, err = domain.FromStringList(req.DnsLabels)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid DNS labels: %v", err)
|
||||
}
|
||||
}
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if len(req.NatExternalIPs) > 0 || (req.CleanNATExternalIPs != nil && *req.CleanNATExternalIPs) {
|
||||
s.config.NATExternalIPs = req.NatExternalIPs
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if configChanged {
|
||||
if err := internal.WriteOutConfig(s.latestConfigInput.ConfigPath, s.config); err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "write config: %v", err)
|
||||
}
|
||||
log.Debug("Configuration updated successfully")
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) onSessionExpire() {
|
||||
if runtime.GOOS != "windows" {
|
||||
isUIActive := internal.CheckUIApp()
|
||||
|
||||
@@ -206,7 +206,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -59,16 +59,16 @@ type Info struct {
|
||||
Environment Environment
|
||||
Files []File // for posture checks
|
||||
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool
|
||||
|
||||
showAdvancedSettings: showSettings,
|
||||
showNetworks: showNetworks,
|
||||
update: version.NewUpdate(),
|
||||
update: version.NewUpdate("nb/client-ui"),
|
||||
}
|
||||
|
||||
s.eventHandler = newEventHandler(s)
|
||||
@@ -879,7 +879,7 @@ func (s *serviceClient) onUpdateAvailable() {
|
||||
func (s *serviceClient) onSessionExpire() {
|
||||
s.sendNotification = true
|
||||
if s.sendNotification {
|
||||
s.eventHandler.runSelfCommand("login-url", "true")
|
||||
go s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
|
||||
s.sendNotification = false
|
||||
}
|
||||
}
|
||||
@@ -992,21 +992,6 @@ func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
|
||||
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
|
||||
func (s *serviceClient) showLoginURL() {
|
||||
|
||||
resp, err := s.login(false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to fetch login URL: %v", err)
|
||||
return
|
||||
}
|
||||
verificationURL := resp.VerificationURIComplete
|
||||
if verificationURL == "" {
|
||||
verificationURL = resp.VerificationURI
|
||||
}
|
||||
|
||||
if verificationURL == "" {
|
||||
log.Error("no verification URL provided in the login response")
|
||||
return
|
||||
}
|
||||
|
||||
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
|
||||
|
||||
if s.wLoginURL == nil {
|
||||
@@ -1025,6 +1010,21 @@ func (s *serviceClient) showLoginURL() {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.login(false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to fetch login URL: %v", err)
|
||||
return
|
||||
}
|
||||
verificationURL := resp.VerificationURIComplete
|
||||
if verificationURL == "" {
|
||||
verificationURL = resp.VerificationURI
|
||||
}
|
||||
|
||||
if verificationURL == "" {
|
||||
log.Error("no verification URL provided in the login response")
|
||||
return
|
||||
}
|
||||
|
||||
if err := openURL(verificationURL); err != nil {
|
||||
log.Errorf("failed to open login URL: %v", err)
|
||||
return
|
||||
@@ -1038,7 +1038,19 @@ func (s *serviceClient) showLoginURL() {
|
||||
}
|
||||
|
||||
label.SetText("Re-authentication successful.\nReconnecting")
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
label.SetText("Already connected.\nClosing this window.")
|
||||
time.Sleep(2 * time.Second)
|
||||
s.wLoginURL.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
|
||||
|
||||
@@ -122,7 +122,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
|
||||
go func() {
|
||||
defer h.client.mAdvancedSettings.Enable()
|
||||
defer h.client.getSrvConfig()
|
||||
h.runSelfCommand("settings", "true")
|
||||
h.runSelfCommand(h.client.ctx, "settings", "true")
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
|
||||
h.client.mCreateDebugBundle.Disable()
|
||||
go func() {
|
||||
defer h.client.mCreateDebugBundle.Enable()
|
||||
h.runSelfCommand("debug", "true")
|
||||
h.runSelfCommand(h.client.ctx, "debug", "true")
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ func (h *eventHandler) handleNetworksClick() {
|
||||
h.client.mNetworks.Disable()
|
||||
go func() {
|
||||
defer h.client.mNetworks.Enable()
|
||||
h.runSelfCommand("networks", "true")
|
||||
h.runSelfCommand(h.client.ctx, "networks", "true")
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -172,14 +172,14 @@ func (h *eventHandler) updateConfigWithErr() {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) runSelfCommand(command, arg string) {
|
||||
func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
|
||||
proc, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Errorf("error getting executable path: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command(proc,
|
||||
cmd := exec.CommandContext(ctx, proc,
|
||||
fmt.Sprintf("--%s=%s", command, arg),
|
||||
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAI
|
||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
||||
|
||||
# Signal
|
||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
||||
@@ -60,7 +61,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
|
||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
|
||||
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-0}
|
||||
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
|
||||
# Dashboard
|
||||
@@ -139,3 +140,4 @@ export NETBIRD_RELAY_PORT
|
||||
export NETBIRD_RELAY_ENDPOINT
|
||||
export NETBIRD_RELAY_AUTH_SECRET
|
||||
export NETBIRD_RELAY_TAG
|
||||
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -791,7 +791,6 @@ services:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
- '8080:8080'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
"0.0.0.0/0"
|
||||
]
|
||||
},
|
||||
"DisableDefaultPolicy": $NETBIRD_MGMT_DISABLE_DEFAULT_POLICY,
|
||||
"Datadir": "",
|
||||
"DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY",
|
||||
"StoreConfig": {
|
||||
|
||||
@@ -92,7 +92,8 @@ NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
|
||||
# Disable default all-to-all policy for new accounts
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
||||
# -------------------------------------------
|
||||
# Relay settings
|
||||
# -------------------------------------------
|
||||
|
||||
@@ -29,3 +29,4 @@ NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
|
||||
NETBIRD_RELAY_PORT=33445
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -100,7 +100,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ var (
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
proxyController := integrations.NewController(store)
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager)
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
@@ -357,6 +357,13 @@ var (
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
|
||||
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
|
||||
|
||||
update := version.NewUpdate("nb/management")
|
||||
update.SetDaemonVersion(version.NetbirdVersion())
|
||||
update.SetOnUpdateListener(func() {
|
||||
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
|
||||
})
|
||||
defer update.StopWatch()
|
||||
|
||||
SetupCloseHandler()
|
||||
|
||||
<-stopCh
|
||||
|
||||
@@ -102,6 +102,8 @@ type DefaultAccountManager struct {
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
disableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||
@@ -170,6 +172,7 @@ func BuildManager(
|
||||
proxyController port_forwarding.Controller,
|
||||
settingsManager settings.Manager,
|
||||
permissionsManager permissions.Manager,
|
||||
disableDefaultPolicy bool,
|
||||
) (*DefaultAccountManager, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -195,6 +198,7 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -543,7 +547,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain)
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
@@ -1688,7 +1692,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
@@ -1731,7 +1735,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
@@ -1833,7 +1837,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
},
|
||||
}
|
||||
|
||||
if err := newAccount.AddAllGroup(); err != nil {
|
||||
if err := newAccount.AddAllGroup(am.disableDefaultPolicy); err != nil {
|
||||
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||
}
|
||||
|
||||
@@ -1853,40 +1857,49 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
var account *types.Account
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
account, err = transaction.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if account.IsDomainPrimaryAccount {
|
||||
return nil
|
||||
}
|
||||
|
||||
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
|
||||
|
||||
// error is not a not found error
|
||||
if handleNotFound(err) != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// a primary account already exists for this private domain
|
||||
if err == nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": accountId,
|
||||
"existingAccountId": existingPrimaryAccountID,
|
||||
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
|
||||
return status.Errorf(status.Internal, "cannot update account to primary")
|
||||
}
|
||||
|
||||
account.IsDomainPrimaryAccount = true
|
||||
|
||||
if err := transaction.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": accountId,
|
||||
}).Errorf("failed to update account to primary: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to update account to primary")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.IsDomainPrimaryAccount {
|
||||
return account, nil
|
||||
}
|
||||
|
||||
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
|
||||
|
||||
// error is not a not found error
|
||||
if handleNotFound(err) != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// a primary account already exists for this private domain
|
||||
if err == nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": accountId,
|
||||
"existingAccountId": existingPrimaryAccountID,
|
||||
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
|
||||
return nil, status.Errorf(status.Internal, "cannot update account to primary")
|
||||
}
|
||||
|
||||
account.IsDomainPrimaryAccount = true
|
||||
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": accountId,
|
||||
}).Errorf("failed to update account to primary: %v", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to update account to primary")
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -373,7 +373,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -398,7 +398,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -640,7 +640,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
@@ -793,7 +793,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2879,7 +2879,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -217,7 +217,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -267,7 +267,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
|
||||
|
||||
account.Users[dnsRegularUserID] = &types.User{
|
||||
Id: dnsRegularUserID,
|
||||
|
||||
@@ -127,7 +127,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
|
||||
@@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
||||
|
||||
@@ -369,7 +369,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
|
||||
@@ -426,6 +426,10 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "stage-host-1"
|
||||
ephemeral:
|
||||
description: Indicates whether the peer is ephemeral or not
|
||||
type: boolean
|
||||
example: false
|
||||
required:
|
||||
- city_name
|
||||
- connected
|
||||
@@ -450,6 +454,7 @@ components:
|
||||
- approval_required
|
||||
- serial_number
|
||||
- extra_dns_labels
|
||||
- ephemeral
|
||||
AccessiblePeer:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PeerMinimum'
|
||||
|
||||
@@ -1016,6 +1016,9 @@ type Peer struct {
|
||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DnsLabel string `json:"dns_label"`
|
||||
|
||||
// Ephemeral Indicates whether the peer is ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// ExtraDnsLabels Extra DNS labels added to the peer
|
||||
ExtraDnsLabels []string `json:"extra_dns_labels"`
|
||||
|
||||
@@ -1097,6 +1100,9 @@ type PeerBatch struct {
|
||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DnsLabel string `json:"dns_label"`
|
||||
|
||||
// Ephemeral Indicates whether the peer is ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// ExtraDnsLabels Extra DNS labels added to the peer
|
||||
ExtraDnsLabels []string `json:"extra_dns_labels"`
|
||||
|
||||
|
||||
@@ -365,6 +365,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
package testing_tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
@@ -138,7 +137,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -37,21 +37,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
a, err := transaction.GetAccountByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var extra *types.ExtraSettings
|
||||
var extra *types.ExtraSettings
|
||||
|
||||
if a.Settings.Extra != nil {
|
||||
extra = a.Settings.Extra
|
||||
} else {
|
||||
extra = &types.ExtraSettings{}
|
||||
a.Settings.Extra = extra
|
||||
}
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
if a.Settings.Extra != nil {
|
||||
extra = a.Settings.Extra
|
||||
} else {
|
||||
extra = &types.ExtraSettings{}
|
||||
a.Settings.Extra = extra
|
||||
}
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return transaction.SaveAccount(ctx, a)
|
||||
})
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
@@ -81,15 +83,12 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
||||
var peers []*nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err = transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
return err
|
||||
})
|
||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -444,7 +444,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
if err != nil {
|
||||
cleanup()
|
||||
|
||||
@@ -211,7 +211,7 @@ func startServer(
|
||||
port_forwarding.NewControllerMock(),
|
||||
settingsMockManager,
|
||||
permissionsManager,
|
||||
)
|
||||
false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating an account manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -184,7 +184,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
ephemeralPeersSKs int
|
||||
ephemeralPeersSKUsage int
|
||||
activePeersLastDay int
|
||||
activeUserPeersLastDay int
|
||||
osPeers map[string]int
|
||||
activeUsersLastDay map[string]struct{}
|
||||
userPeers int
|
||||
rules int
|
||||
rulesProtocol map[string]int
|
||||
@@ -203,6 +205,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
version string
|
||||
peerActiveVersions []string
|
||||
osUIClients map[string]int
|
||||
rosenpassEnabled int
|
||||
)
|
||||
start := time.Now()
|
||||
metricsProperties := make(properties)
|
||||
@@ -210,6 +213,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
osUIClients = make(map[string]int)
|
||||
rulesProtocol = make(map[string]int)
|
||||
rulesDirection = make(map[string]int)
|
||||
activeUsersLastDay = make(map[string]struct{})
|
||||
uptime = time.Since(w.startupTime).Seconds()
|
||||
connections := w.connManager.GetAllConnectedPeers()
|
||||
version = nbversion.NetbirdVersion()
|
||||
@@ -277,10 +281,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
for _, peer := range account.Peers {
|
||||
peers++
|
||||
|
||||
if peer.SSHEnabled {
|
||||
if peer.SSHEnabled || peer.Meta.Flags.ServerSSHAllowed {
|
||||
peersSSHEnabled++
|
||||
}
|
||||
|
||||
if peer.Meta.Flags.RosenpassEnabled {
|
||||
rosenpassEnabled++
|
||||
}
|
||||
|
||||
if peer.UserID != "" {
|
||||
userPeers++
|
||||
}
|
||||
@@ -299,6 +307,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
_, connected := connections[peer.ID]
|
||||
if connected || peer.Status.LastSeen.After(w.lastRun) {
|
||||
activePeersLastDay++
|
||||
if peer.UserID != "" {
|
||||
activeUserPeersLastDay++
|
||||
activeUsersLastDay[peer.UserID] = struct{}{}
|
||||
}
|
||||
osActiveKey := osKey + "_active"
|
||||
osActiveCount := osPeers[osActiveKey]
|
||||
osPeers[osActiveKey] = osActiveCount + 1
|
||||
@@ -320,6 +332,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs
|
||||
metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage
|
||||
metricsProperties["active_peers_last_day"] = activePeersLastDay
|
||||
metricsProperties["active_user_peers_last_day"] = activeUserPeersLastDay
|
||||
metricsProperties["active_users_last_day"] = len(activeUsersLastDay)
|
||||
metricsProperties["user_peers"] = userPeers
|
||||
metricsProperties["rules"] = rules
|
||||
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
|
||||
@@ -338,6 +352,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["ui_clients"] = uiClient
|
||||
metricsProperties["idp_manager"] = w.idpManager
|
||||
metricsProperties["store_engine"] = w.dataSource.GetStoreEngine()
|
||||
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
|
||||
|
||||
for protocol, count := range rulesProtocol {
|
||||
metricsProperties["rules_protocol_"+protocol] = count
|
||||
|
||||
@@ -47,8 +47,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
"1": {
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
SSHEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
|
||||
SSHEnabled: false,
|
||||
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1", Flags: nbpeer.Flags{ServerSSHAllowed: true, RosenpassEnabled: true}},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
@@ -312,7 +312,19 @@ func TestGenerateProperties(t *testing.T) {
|
||||
}
|
||||
|
||||
if properties["posture_checks"] != 2 {
|
||||
t.Errorf("expected 1 posture_checks, got %d", properties["posture_checks"])
|
||||
t.Errorf("expected 2 posture_checks, got %d", properties["posture_checks"])
|
||||
}
|
||||
|
||||
if properties["rosenpass_enabled"] != 1 {
|
||||
t.Errorf("expected 1 rosenpass_enabled, got %d", properties["rosenpass_enabled"])
|
||||
}
|
||||
|
||||
if properties["active_user_peers_last_day"] != 2 {
|
||||
t.Errorf("expected 2 active_user_peers_last_day, got %d", properties["active_user_peers_last_day"])
|
||||
}
|
||||
|
||||
if properties["active_users_last_day"] != 1 {
|
||||
t.Errorf("expected 1 active_users_last_day, got %d", properties["active_users_last_day"])
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -779,7 +779,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -848,7 +848,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@@ -1149,7 +1149,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peer.ID {
|
||||
return peer, nil
|
||||
@@ -1169,7 +1169,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
globalStart := time.Now()
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
@@ -1204,18 +1204,27 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||
|
||||
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -1223,7 +1232,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}(peer)
|
||||
}
|
||||
@@ -1232,7 +1244,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
|
||||
wg.Wait()
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -480,7 +480,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -667,7 +667,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
@@ -737,7 +737,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[regularUser] = &types.User{
|
||||
Id: regularUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1267,7 +1267,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1342,7 +1342,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1477,7 +1477,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1546,7 +1546,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -2052,7 +2052,7 @@ func Test_DeletePeer(t *testing.T) {
|
||||
// account with an admin and a regular user
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Peers = map[string]*nbpeer.Peer{
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
|
||||
@@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
ID: "peerB",
|
||||
IP: net.ParseIP("100.65.80.39"),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"},
|
||||
},
|
||||
"peerC": {
|
||||
ID: "peerC",
|
||||
@@ -63,6 +64,12 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
IP: net.ParseIP("100.65.31.2"),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
},
|
||||
"peerK": {
|
||||
ID: "peerK",
|
||||
IP: net.ParseIP("100.32.80.1"),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"GroupAll": {
|
||||
@@ -111,6 +118,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
"peerI",
|
||||
},
|
||||
},
|
||||
"GroupWorkflow": {
|
||||
ID: "GroupWorkflow",
|
||||
Name: "workflow",
|
||||
Peers: []string{
|
||||
"peerK",
|
||||
},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
@@ -189,6 +203,39 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "RuleWorkflow",
|
||||
Name: "Workflow",
|
||||
Description: "No description",
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleWorkflow",
|
||||
Name: "Workflow",
|
||||
Description: "No description",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
PortRanges: []types.RulePortRange{
|
||||
{
|
||||
Start: 8088,
|
||||
End: 8088,
|
||||
},
|
||||
{
|
||||
Start: 9090,
|
||||
End: 9095,
|
||||
},
|
||||
},
|
||||
Sources: []string{
|
||||
"GroupWorkflow",
|
||||
},
|
||||
Destinations: []string{
|
||||
"GroupDMZ",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -199,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers)
|
||||
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
|
||||
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
|
||||
assert.Len(t, peers, 8)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
@@ -364,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
assert.True(t, contains, "rule not found in expected rules %#v", rule)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
|
||||
assert.Len(t, peers, 1)
|
||||
assert.Contains(t, peers, account.Peers["peerI"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.31.2",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "8088",
|
||||
PolicyID: "RuleWorkflow",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.31.2",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "8088",
|
||||
PolicyID: "RuleWorkflow",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, firewallRules, expectedFirewallRules)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
@@ -466,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
@@ -487,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
@@ -517,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
@@ -541,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
@@ -563,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -748,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -758,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 1)
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -775,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -785,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -800,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@@ -827,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
@@ -24,20 +24,12 @@ func sanitizeVersion(version string) string {
|
||||
}
|
||||
|
||||
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
peerVersion := sanitizeVersion(peer.Meta.WtVersion)
|
||||
minVersion := sanitizeVersion(n.MinVersion)
|
||||
|
||||
peerNBVersion, err := version.NewVersion(peerVersion)
|
||||
meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
constraints, err := version.NewConstraint(">= " + minVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if constraints.Check(peerNBVersion) {
|
||||
if meetsMin {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version
|
||||
func MeetsMinVersion(minVer, peerVer string) (bool, error) {
|
||||
peerVer = sanitizeVersion(peerVer)
|
||||
minVer = sanitizeVersion(minVer)
|
||||
|
||||
peerNBVer, err := version.NewVersion(peerVer)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
constraints, err := version.NewConstraint(">= " + minVer)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return constraints.Check(peerNBVer), nil
|
||||
}
|
||||
|
||||
@@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeetsMinVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
minVer string
|
||||
peerVer string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Peer version greater than min version",
|
||||
minVer: "0.26.0",
|
||||
peerVer: "0.60.1",
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Peer version equals min version",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "1.0.0",
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Peer version less than min version",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "0.9.9",
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Peer version with pre-release tag greater than min version",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "1.0.1-alpha",
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid peer version format",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "dev",
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid min version format",
|
||||
minVer: "invalid.version",
|
||||
peerVer: "1.0.0",
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := MeetsMinVersion(tt.minVer, tt.peerVer)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID)
|
||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
prefix := checkRoute.Network
|
||||
domains := checkRoute.Domains
|
||||
|
||||
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||
seenPeers := make(map[string]bool)
|
||||
@@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
for _, prefixRoute := range routesWithPrefix {
|
||||
// we skip route(s) with the same network ID as we want to allow updating of the existing route
|
||||
// when creating a new route routeID is newly generated so nothing will be skipped
|
||||
if routeID == prefixRoute.ID {
|
||||
if checkRoute.ID == prefixRoute.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefixRoute.Peer != "" {
|
||||
seenPeers[string(prefixRoute.ID)] = true
|
||||
}
|
||||
|
||||
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, groupID := range prefixRoute.PeerGroups {
|
||||
seenPeerGroups[groupID] = true
|
||||
|
||||
group := account.GetGroup(groupID)
|
||||
if group == nil {
|
||||
group, ok := peerGroupsMap[groupID]
|
||||
if !ok || group == nil {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||
getRouteDescriptor(prefix, domains), groupID,
|
||||
@@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
}
|
||||
|
||||
if peerID != "" {
|
||||
if peerID := checkRoute.Peer; peerID != "" {
|
||||
// check that peerID exists and is not in any route as single peer or part of the group
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeers[peerID]; ok {
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||
@@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
|
||||
// check that peerGroupIDs are not in any route peerGroups list
|
||||
for _, groupID := range peerGroupIDs {
|
||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
||||
|
||||
for _, groupID := range checkRoute.PeerGroups {
|
||||
group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
|
||||
if _, ok := seenPeerGroups[groupID]; ok {
|
||||
return status.Errorf(
|
||||
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
|
||||
@@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
|
||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range group.Peers {
|
||||
if _, ok := seenPeers[id]; ok {
|
||||
peer := account.GetPeer(id)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
peer, ok := peersMap[id]
|
||||
if !ok || peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
|
||||
}
|
||||
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||
@@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(domains) > 0 && prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
|
||||
if len(domains) == 0 && !prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
|
||||
}
|
||||
var newRoute *route.Route
|
||||
var updateAccountPeers bool
|
||||
|
||||
if len(domains) > 0 {
|
||||
prefix = getPlaceholderIP()
|
||||
}
|
||||
|
||||
if peerID != "" && len(peerGroupIDs) != 0 {
|
||||
return nil, status.Errorf(
|
||||
status.InvalidArgument,
|
||||
"peer with ID %s and peers group %s should not be provided at the same time",
|
||||
peerID, peerGroupIDs)
|
||||
}
|
||||
|
||||
var newRoute route.Route
|
||||
newRoute.ID = route.ID(xid.New().String())
|
||||
|
||||
if len(peerGroupIDs) > 0 {
|
||||
err = validateGroups(peerGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
newRoute = &route.Route{
|
||||
ID: route.ID(xid.New().String()),
|
||||
AccountID: accountID,
|
||||
Network: prefix,
|
||||
Domains: domains,
|
||||
KeepRoute: keepRoute,
|
||||
NetID: netID,
|
||||
Description: description,
|
||||
Peer: peerID,
|
||||
PeerGroups: peerGroupIDs,
|
||||
NetworkType: networkType,
|
||||
Masquerade: masquerade,
|
||||
Metric: metric,
|
||||
Enabled: enabled,
|
||||
Groups: groups,
|
||||
AccessControlGroups: accessControlGroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessControlGroupIDs) > 0 {
|
||||
err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
err = validateGroups(groups, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRoute.Peer = peerID
|
||||
newRoute.PeerGroups = peerGroupIDs
|
||||
newRoute.Network = prefix
|
||||
newRoute.Domains = domains
|
||||
newRoute.NetworkType = networkType
|
||||
newRoute.Description = description
|
||||
newRoute.NetID = netID
|
||||
newRoute.Masquerade = masquerade
|
||||
newRoute.Metric = metric
|
||||
newRoute.Enabled = enabled
|
||||
newRoute.Groups = groups
|
||||
newRoute.KeepRoute = keepRoute
|
||||
newRoute.AccessControlGroups = accessControlGroupIDs
|
||||
|
||||
if account.Routes == nil {
|
||||
account.Routes = make(map[route.ID]*route.Route)
|
||||
}
|
||||
|
||||
account.Routes[newRoute.ID] = &newRoute
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
return &newRoute, nil
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return newRoute, nil
|
||||
}
|
||||
|
||||
// SaveRoute saves route
|
||||
@@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var oldRoute *route.Route
|
||||
var oldRouteAffectsPeers bool
|
||||
var newRouteAffectsPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var route *route.Route
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
})
|
||||
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
|
||||
if routeToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||
}
|
||||
@@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
@@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
||||
}
|
||||
|
||||
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
|
||||
}
|
||||
|
||||
// validateRouteGroups validates the route groups and returns the validated groups map.
|
||||
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
|
||||
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
|
||||
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(routeToSave.PeerGroups) > 0 {
|
||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeToSave.AccessControlGroups) > 0 {
|
||||
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute := account.Routes[routeToSave.ID]
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routy := account.Routes[routeID]
|
||||
if routy == nil {
|
||||
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
|
||||
}
|
||||
delete(account.Routes, routeID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
if am.isRouteChangeAffectPeers(account, routy) {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
@@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool {
|
||||
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
|
||||
if route.Peer != "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, r := range accountRoutes {
|
||||
dynamic := r.IsDynamic()
|
||||
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||
!dynamic && r.Network.String() == prefix.String() {
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||
@@ -1305,7 +1305,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error {
|
||||
func NewOperationNotFoundError(operation operations.Operation) error {
|
||||
return Errorf(NotFound, "operation: %s not found", operation)
|
||||
}
|
||||
|
||||
func NewRouteNotFoundError(routeID string) error {
|
||||
return Errorf(NotFound, "route: %s not found", routeID)
|
||||
}
|
||||
|
||||
@@ -23,8 +23,6 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -34,6 +32,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -1185,7 +1184,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
|
||||
for _, account := range fileStore.GetAllAccounts(ctx) {
|
||||
_, err = account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db, lockStrength, accountID)
|
||||
var routes []*route.Route
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&routes, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get routes from store")
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID)
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
|
||||
var route *route.Route
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&route, accountAndIDQueryCondition, accountID, routeID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewRouteNotFoundError(routeID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get route from store")
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// SaveRoute saves a route to the database.
|
||||
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save route to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes a route from the database.
|
||||
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete route from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewRouteNotFoundError(routeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
@@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
tx := db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var record []T
|
||||
|
||||
result := tx.Find(&record, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
|
||||
tx := db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var record T
|
||||
|
||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// SaveDNSSettings saves the DNS settings to the store.
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
|
||||
@@ -19,21 +19,17 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
|
||||
@@ -2048,7 +2044,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
if err := acc.AddAllGroup(false); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
@@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8003, len(accountGroups))
|
||||
}
|
||||
func TestSqlStore_GetAccountRoutes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve routes by existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routes, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetRouteByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
routeID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing route",
|
||||
routeID: "ct03t427qv97vmtmglog",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing route",
|
||||
routeID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty route ID",
|
||||
routeID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, route)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, route)
|
||||
require.Equal(t, tt.routeID, string(route.ID))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveRoute(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
route := &route2.Route{
|
||||
ID: "route-id",
|
||||
AccountID: accountID,
|
||||
Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
NetID: "netID",
|
||||
PeerGroups: []string{"routeA"},
|
||||
NetworkType: route2.IPv4Network,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupA"},
|
||||
AccessControlGroups: []string{},
|
||||
}
|
||||
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, route, saveRoute)
|
||||
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteRoute(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
routeID := "ct03t427qv97vmtmglog"
|
||||
|
||||
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
|
||||
require.NoError(t, err)
|
||||
|
||||
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, route)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountMeta(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
|
||||
@@ -145,7 +145,9 @@ type Store interface {
|
||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
||||
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
@@ -389,7 +391,7 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
|
||||
|
||||
_, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
return err
|
||||
}
|
||||
shouldSave = true
|
||||
|
||||
@@ -18,6 +18,10 @@ type UpdateChannelMetrics struct {
|
||||
getAllConnectedPeersDurationMicro metric.Int64Histogram
|
||||
getAllConnectedPeers metric.Int64Histogram
|
||||
hasChannelDurationMicro metric.Int64Histogram
|
||||
calcPostureChecksDurationMicro metric.Int64Histogram
|
||||
calcPeerNetworkMapDurationMs metric.Int64Histogram
|
||||
mergeNetworkMapDurationMicro metric.Int64Histogram
|
||||
toSyncResponseDurationMicro metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -89,6 +93,38 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
calcPostureChecksDurationMicro, err := meter.Int64Histogram("management.updatechannel.calc.posturechecks.duration.micro",
|
||||
metric.WithUnit("microseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to get the posture checks for a peer"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
calcPeerNetworkMapDurationMs, err := meter.Int64Histogram("management.updatechannel.calc.networkmap.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to calculate the network map for a peer"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mergeNetworkMapDurationMicro, err := meter.Int64Histogram("management.updatechannel.merge.networkmap.duration.micro",
|
||||
metric.WithUnit("microseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to merge the network maps for a peer"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toSyncResponseDurationMicro, err := meter.Int64Histogram("management.updatechannel.tosyncresponse.duration.micro",
|
||||
metric.WithUnit("microseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to convert the network map to sync response"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UpdateChannelMetrics{
|
||||
createChannelDurationMicro: createChannelDurationMicro,
|
||||
closeChannelDurationMicro: closeChannelDurationMicro,
|
||||
@@ -98,6 +134,10 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro,
|
||||
getAllConnectedPeers: getAllConnectedPeers,
|
||||
hasChannelDurationMicro: hasChannelDurationMicro,
|
||||
calcPostureChecksDurationMicro: calcPostureChecksDurationMicro,
|
||||
calcPeerNetworkMapDurationMs: calcPeerNetworkMapDurationMs,
|
||||
mergeNetworkMapDurationMicro: mergeNetworkMapDurationMicro,
|
||||
toSyncResponseDurationMicro: toSyncResponseDurationMicro,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -137,3 +177,19 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration
|
||||
func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) {
|
||||
metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountCalcPostureChecksDuration(duration time.Duration) {
|
||||
metrics.calcPostureChecksDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountCalcPeerNetworkMapDuration(duration time.Duration) {
|
||||
metrics.calcPeerNetworkMapDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountMergeNetworkMapDuration(duration time.Duration) {
|
||||
metrics.mergeNetworkMapDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountToSyncResponseDuration(duration time.Duration) {
|
||||
metrics.toSyncResponseDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
@@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
|
||||
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
||||
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
|
||||
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
|
||||
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -36,6 +36,9 @@ const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
|
||||
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
|
||||
firewallRuleMinPortRangesVer = "0.48.0"
|
||||
)
|
||||
|
||||
type LookupMap map[string]struct{}
|
||||
@@ -248,7 +251,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap)
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
@@ -961,8 +964,9 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
@@ -973,8 +977,8 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
@@ -1003,7 +1007,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
|
||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||
// generated. The accumulator function returns the result of all the generator calls.
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
rulesExists := make(map[string]struct{})
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
@@ -1051,17 +1055,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
continue
|
||||
}
|
||||
|
||||
for _, port := range rule.Ports {
|
||||
pr := fr // clone rule and add set new port
|
||||
pr.Port = port
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
|
||||
for _, portRange := range rule.PortRanges {
|
||||
pr := fr
|
||||
pr.PortRange = portRange
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
return peers, rules
|
||||
@@ -1552,7 +1546,7 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st
|
||||
}
|
||||
|
||||
// AddAllGroup to account object if it doesn't exist
|
||||
func (a *Account) AddAllGroup() error {
|
||||
func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
if len(a.Groups) == 0 {
|
||||
allGroup := &Group{
|
||||
ID: xid.New().String(),
|
||||
@@ -1564,6 +1558,10 @@ func (a *Account) AddAllGroup() error {
|
||||
}
|
||||
a.Groups = map[string]*Group{allGroup.ID: allGroup}
|
||||
|
||||
if disableDefaultPolicy {
|
||||
return nil
|
||||
}
|
||||
|
||||
id := xid.New().String()
|
||||
|
||||
defaultPolicy := &Policy{
|
||||
@@ -1590,3 +1588,45 @@ func (a *Account) AddAllGroup() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
var expanded []*FirewallRule
|
||||
|
||||
if len(rule.Ports) > 0 {
|
||||
for _, port := range rule.Ports {
|
||||
fr := base
|
||||
fr.Port = port
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
return expanded
|
||||
}
|
||||
|
||||
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
|
||||
for _, portRange := range rule.PortRanges {
|
||||
fr := base
|
||||
|
||||
if supportPortRanges {
|
||||
fr.PortRange = portRange
|
||||
} else {
|
||||
// Peer doesn't support port ranges, only allow single-port ranges
|
||||
if portRange.Start != portRange.End {
|
||||
continue
|
||||
}
|
||||
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||
}
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
// peerSupportsPortRanges checks if the peer version supports port ranges.
|
||||
func peerSupportsPortRanges(peerVer string) bool {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
return true
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
@@ -53,6 +53,9 @@ type Config struct {
|
||||
StoreConfig StoreConfig
|
||||
|
||||
ReverseProxy ReverseProxy
|
||||
|
||||
// disable default all-to-all policy
|
||||
DisableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||
|
||||
@@ -76,7 +76,6 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule
|
||||
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
|
||||
} else {
|
||||
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
|
||||
|
||||
}
|
||||
|
||||
// TODO: generate IPv6 rules for dynamic routes
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = s.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -103,7 +103,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
@@ -131,7 +131,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
@@ -163,7 +163,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -188,7 +188,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -213,7 +213,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
@@ -256,7 +256,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -296,7 +296,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -406,7 +406,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -453,7 +453,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -501,7 +501,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -532,7 +532,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -639,7 +639,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
@@ -678,7 +678,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -705,7 +705,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -792,7 +792,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -952,7 +952,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -988,7 +988,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
|
||||
|
||||
@@ -1030,7 +1030,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
externalUser := &types.User{
|
||||
Id: "externalUser",
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1098,7 +1098,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1132,7 +1132,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1499,7 +1499,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "")
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false)
|
||||
targetId := "user2"
|
||||
account1.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
@@ -1508,7 +1508,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false)
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account2))
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
@@ -1535,7 +1535,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "")
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false)
|
||||
account1.Settings.RegularUsersViewBlocked = false
|
||||
account1.Users["blocked-user"] = &types.User{
|
||||
Id: "blocked-user",
|
||||
@@ -1557,7 +1557,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, store.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "")
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false)
|
||||
account2.Users["settings-blocked-user"] = &types.User{
|
||||
Id: "settings-blocked-user",
|
||||
Role: types.UserRoleUser,
|
||||
|
||||
@@ -130,7 +130,7 @@ repo_gpgcheck=1
|
||||
EOF
|
||||
}
|
||||
|
||||
add_aur_repo() {
|
||||
install_aur_package() {
|
||||
INSTALL_PKGS="git base-devel go"
|
||||
REMOVE_PKGS=""
|
||||
|
||||
@@ -154,8 +154,10 @@ add_aur_repo() {
|
||||
cd netbird-ui && makepkg -sri --noconfirm
|
||||
fi
|
||||
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
if [ -n "$REMOVE_PKGS" ]; then
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
fi
|
||||
}
|
||||
|
||||
prepare_tun_module() {
|
||||
@@ -277,7 +279,9 @@ install_netbird() {
|
||||
;;
|
||||
pacman)
|
||||
${SUDO} pacman -Syy
|
||||
add_aur_repo
|
||||
install_aur_package
|
||||
# in-line with the docs at https://wiki.archlinux.org/title/Netbird
|
||||
${SUDO} systemctl enable --now netbird@main.service
|
||||
;;
|
||||
pkg)
|
||||
# Check if the package is already installed
|
||||
@@ -494,4 +498,4 @@ case "$UPDATE_FLAG" in
|
||||
;;
|
||||
*)
|
||||
install_netbird
|
||||
esac
|
||||
esac
|
||||
|
||||
@@ -21,6 +21,7 @@ var (
|
||||
// Update fetch the version info periodically and notify the onUpdateListener in case the UI version or the
|
||||
// daemon version are deprecated
|
||||
type Update struct {
|
||||
httpAgent string
|
||||
uiVersion *goversion.Version
|
||||
daemonVersion *goversion.Version
|
||||
latestAvailable *goversion.Version
|
||||
@@ -34,7 +35,7 @@ type Update struct {
|
||||
}
|
||||
|
||||
// NewUpdate instantiate Update and start to fetch the new version information
|
||||
func NewUpdate() *Update {
|
||||
func NewUpdate(httpAgent string) *Update {
|
||||
currentVersion, err := goversion.NewVersion(version)
|
||||
if err != nil {
|
||||
currentVersion, _ = goversion.NewVersion("0.0.0")
|
||||
@@ -43,6 +44,7 @@ func NewUpdate() *Update {
|
||||
latestAvailable, _ := goversion.NewVersion("0.0.0")
|
||||
|
||||
u := &Update{
|
||||
httpAgent: httpAgent,
|
||||
latestAvailable: latestAvailable,
|
||||
uiVersion: currentVersion,
|
||||
fetchTicker: time.NewTicker(fetchPeriod),
|
||||
@@ -112,7 +114,15 @@ func (u *Update) startFetcher() {
|
||||
func (u *Update) fetchVersion() bool {
|
||||
log.Debugf("fetching version info from %s", versionURL)
|
||||
|
||||
resp, err := http.Get(versionURL)
|
||||
req, err := http.NewRequest("GET", versionURL, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create request for version info: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", u.httpAgent)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to fetch version info: %s", err)
|
||||
return false
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const httpAgent = "pkg/test"
|
||||
|
||||
func TestNewUpdate(t *testing.T) {
|
||||
version = "1.0.0"
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -21,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate()
|
||||
u := NewUpdate(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
@@ -46,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate()
|
||||
u := NewUpdate(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
@@ -71,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate()
|
||||
u := NewUpdate(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
|
||||
Reference in New Issue
Block a user