mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Compare commits
32 Commits
tshoot/win
...
feature/op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
515ce9e3af | ||
|
|
89383b7f01 | ||
|
|
db34162733 | ||
|
|
bd761e2177 | ||
|
|
4e1b95a4c6 | ||
|
|
05993af7bf | ||
|
|
9d1cb00570 | ||
|
|
543731df45 | ||
|
|
e6628ec231 | ||
|
|
41d4dd2aff | ||
|
|
30bed57711 | ||
|
|
6960b68322 | ||
|
|
3b3aa18148 | ||
|
|
93045f3e3a | ||
|
|
fd3c1dea8e | ||
|
|
48aff7a26e | ||
|
|
83dfe8e3a3 | ||
|
|
38e10af2d9 | ||
|
|
99854a126a | ||
|
|
a75f982fcd | ||
|
|
e7a6483912 | ||
|
|
30ede299b8 | ||
|
|
e3b76448f3 | ||
|
|
e0de86d6c9 | ||
|
|
5204d07811 | ||
|
|
5ea24ba56e | ||
|
|
d30cf8706a | ||
|
|
15a2feb723 | ||
|
|
91b2f9fc51 | ||
|
|
76702c8a09 | ||
|
|
061f673a4f | ||
|
|
9505805313 |
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -33,6 +33,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
- name: Check for duplicate constants
|
||||||
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
run: |
|
||||||
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -794,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
|||||||
FQDN: offlinePeer.GetFqdn(),
|
FQDN: offlinePeer.GetFqdn(),
|
||||||
ConnStatus: peer.StatusDisconnected,
|
ConnStatus: peer.StatusDisconnected,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||||
|
|||||||
@@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.agent, err = ice.NewAgent(agentConfig)
|
conn.agent, err = ice.NewAgent(agentConfig)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -285,6 +284,7 @@ func (conn *Conn) Open() error {
|
|||||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -344,6 +344,7 @@ func (conn *Conn) Open() error {
|
|||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -465,9 +466,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
LocalIceCandidateType: pair.Local.Type().String(),
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
|
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||||
Direct: !isRelayCandidate(pair.Local),
|
Direct: !isRelayCandidate(pair.Local),
|
||||||
RosenpassEnabled: rosenpassEnabled,
|
RosenpassEnabled: rosenpassEnabled,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
@@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error {
|
|||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
// State contains the latest state of a peer
|
// State contains the latest state of a peer
|
||||||
type State struct {
|
type State struct {
|
||||||
|
Mux *sync.RWMutex
|
||||||
IP string
|
IP string
|
||||||
PubKey string
|
PubKey string
|
||||||
FQDN string
|
FQDN string
|
||||||
@@ -30,7 +31,38 @@ type State struct {
|
|||||||
BytesRx int64
|
BytesRx int64
|
||||||
Latency time.Duration
|
Latency time.Duration
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
Routes map[string]struct{}
|
routes map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRoute add a single route to routes map
|
||||||
|
func (s *State) AddRoute(network string) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
if s.routes == nil {
|
||||||
|
s.routes = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
s.routes[network] = struct{}{}
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoutes set state routes
|
||||||
|
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
s.routes = routes
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRoute removes a route from the network amp
|
||||||
|
func (s *State) DeleteRoute(network string) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
delete(s.routes, network)
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutes return routes map
|
||||||
|
func (s *State) GetRoutes() map[string]struct{} {
|
||||||
|
s.Mux.RLock()
|
||||||
|
defer s.Mux.RUnlock()
|
||||||
|
return s.routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
// LocalPeerState contains the latest state of the local peer
|
||||||
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
|||||||
PubKey: peerPubKey,
|
PubKey: peerPubKey,
|
||||||
ConnStatus: StatusDisconnected,
|
ConnStatus: StatusDisconnected,
|
||||||
FQDN: fqdn,
|
FQDN: fqdn,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
d.peerListChangedForNotification = true
|
d.peerListChangedForNotification = true
|
||||||
return nil
|
return nil
|
||||||
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
peerState.IP = receivedState.IP
|
peerState.IP = receivedState.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
if receivedState.Routes != nil {
|
if receivedState.GetRoutes() != nil {
|
||||||
peerState.Routes = receivedState.Routes
|
peerState.SetRoutes(receivedState.GetRoutes())
|
||||||
}
|
}
|
||||||
|
|
||||||
skipNotification := shouldSkipNotify(receivedState, peerState)
|
skipNotification := shouldSkipNotify(receivedState, peerState)
|
||||||
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
|
|||||||
s, ok := gstatus.FromError(d.managementError)
|
s, ok := gstatus.FromError(d.managementError)
|
||||||
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
return true
|
return true
|
||||||
|
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package peer
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
|||||||
return fmt.Errorf("get peer state: %v", err)
|
return fmt.Errorf("get peer state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(state.Routes, c.network.String())
|
state.DeleteRoute(c.network.String())
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
}
|
}
|
||||||
@@ -268,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get peer state: %v", err)
|
log.Errorf("Failed to get peer state: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if state.Routes == nil {
|
state.AddRoute(c.network.String())
|
||||||
state.Routes = map[string]struct{}{}
|
|
||||||
}
|
|
||||||
state.Routes[c.network.String()] = struct{}{}
|
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
|
|
||||||
// Init sets up the routing
|
// Init sets up the routing
|
||||||
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
if nbnet.CustomRoutingDisabled() {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
if err := cleanupRouting(); err != nil {
|
if err := cleanupRouting(); err != nil {
|
||||||
log.Warnf("Failed cleaning up routing: %v", err)
|
log.Warnf("Failed cleaning up routing: %v", err)
|
||||||
}
|
}
|
||||||
@@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() {
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.cleanUp()
|
m.serverRouter.cleanUp()
|
||||||
}
|
}
|
||||||
if err := cleanupRouting(); err != nil {
|
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
} else {
|
if err := cleanupRouting(); err != nil {
|
||||||
log.Info("Routing cleanup complete")
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Info("Routing cleanup complete")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.ctx = nil
|
m.ctx = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,9 +219,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||||
switch runtime.GOOS {
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
case "linux", "windows", "darwin":
|
switch runtime.GOOS {
|
||||||
return true
|
case "linux", "windows", "darwin":
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -32,19 +32,31 @@ const (
|
|||||||
rtTablesPath = "/etc/iproute2/rt_tables"
|
rtTablesPath = "/etc/iproute2/rt_tables"
|
||||||
|
|
||||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||||
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||||
|
|
||||||
|
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||||
|
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||||
|
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||||
|
|
||||||
var routeManager = &RouteManager{}
|
var routeManager = &RouteManager{}
|
||||||
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
|
|
||||||
|
// originalSysctl stores the original sysctl values before they are modified
|
||||||
|
var originalSysctl map[string]int
|
||||||
|
|
||||||
|
// determines whether to use the legacy routing setup
|
||||||
|
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||||
|
|
||||||
|
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
|
||||||
|
var sysctlFailed bool
|
||||||
|
|
||||||
type ruleParams struct {
|
type ruleParams struct {
|
||||||
|
priority int
|
||||||
fwmark int
|
fwmark int
|
||||||
tableID int
|
tableID int
|
||||||
family int
|
family int
|
||||||
priority int
|
|
||||||
invert bool
|
invert bool
|
||||||
suppressPrefix int
|
suppressPrefix int
|
||||||
description string
|
description string
|
||||||
@@ -52,10 +64,10 @@ type ruleParams struct {
|
|||||||
|
|
||||||
func getSetupRules() []ruleParams {
|
func getSetupRules() []ruleParams {
|
||||||
return []ruleParams{
|
return []ruleParams{
|
||||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
|
||||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
|
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
|
||||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
|
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,8 +81,6 @@ func getSetupRules() []ruleParams {
|
|||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
//
|
|
||||||
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
|
||||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||||
if isLegacy {
|
if isLegacy {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
@@ -81,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
|||||||
log.Errorf("Error adding routing table name: %v", err)
|
log.Errorf("Error adding routing table name: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
originalValues, err := setupSysctl(wgIface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error setting up sysctl: %v", err)
|
||||||
|
sysctlFailed = true
|
||||||
|
}
|
||||||
|
originalSysctl = originalValues
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||||
@@ -123,11 +140,17 @@ func cleanupRouting() error {
|
|||||||
|
|
||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
|
if err := removeRule(rule); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||||
|
}
|
||||||
|
originalSysctl = nil
|
||||||
|
sysctlFailed = false
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
|
|||||||
return genericAddVPNRoute(prefix, intf)
|
return genericAddVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||||
|
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||||
|
}
|
||||||
|
|
||||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
// TODO remove this once we have ipv6 support
|
||||||
@@ -336,22 +363,8 @@ func flushRoutes(tableID, family int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func enableIPForwarding() error {
|
||||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if it is already enabled
|
|
||||||
// see more: https://github.com/netbirdio/netbird/issues/872
|
|
||||||
if len(bytes) > 0 && bytes[0] == 49 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:gosec
|
|
||||||
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
|
|
||||||
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||||
@@ -429,7 +442,7 @@ func addRule(params ruleParams) error {
|
|||||||
rule.Invert = params.invert
|
rule.Invert = params.invert
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return fmt.Errorf("add routing rule: %w", err)
|
return fmt.Errorf("add routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,43 +459,13 @@ func removeRule(params ruleParams) error {
|
|||||||
rule.Priority = params.priority
|
rule.Priority = params.priority
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleDel(rule); err != nil {
|
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
return fmt.Errorf("remove routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeAllRules(params ruleParams) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
done <- ctx.Err()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := removeRule(params); err != nil {
|
|
||||||
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
done <- nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
done <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case err := <-done:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNextHop adds the gateway and device to the route.
|
// addNextHop adds the gateway and device to the route.
|
||||||
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
@@ -509,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int {
|
|||||||
}
|
}
|
||||||
return netlink.FAMILY_V6
|
return netlink.FAMILY_V6
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||||
|
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||||
|
keys := map[string]int{}
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
|
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[srcValidMarkPath] = oldVal
|
||||||
|
}
|
||||||
|
|
||||||
|
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[rpFilterPath] = oldVal
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||||
|
oldVal, err := setSysctl(i, 2, true)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[i] = oldVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||||
|
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||||
|
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||||
|
currentValue, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||||
|
if err != nil && len(currentValue) > 0 {
|
||||||
|
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||||
|
return currentV, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gosec
|
||||||
|
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||||
|
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||||
|
|
||||||
|
return currentV, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupSysctl(originalSettings map[string]int) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
|
for key, value := range originalSettings {
|
||||||
|
_, err := setSysctl(key, value, false)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
|
|
||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
_, _, err = setupRouting(nil, nil)
|
_, _, err = setupRouting(nil, wgInterface)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, cleanupRouting())
|
assert.NoError(t, cleanupRouting())
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx s
|
|||||||
}
|
}
|
||||||
|
|
||||||
script := fmt.Sprintf(
|
script := fmt.Sprintf(
|
||||||
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`,
|
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
|
||||||
psCmd, addressFamily, destinationPrefix,
|
psCmd, addressFamily, destinationPrefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set the fwmark on the socket.
|
// Set the fwmark on the socket.
|
||||||
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
|
err = nbnet.SetSocketOpt(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
|||||||
BytesRx: peerState.BytesRx,
|
BytesRx: peerState.BytesRx,
|
||||||
BytesTx: peerState.BytesTx,
|
BytesTx: peerState.BytesTx,
|
||||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||||
Routes: maps.Keys(peerState.Routes),
|
Routes: maps.Keys(peerState.GetRoutes()),
|
||||||
Latency: durationpb.New(peerState.Latency),
|
Latency: durationpb.New(peerState.Latency),
|
||||||
}
|
}
|
||||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -60,7 +60,7 @@ require (
|
|||||||
github.com/miekg/dns v1.1.43
|
github.com/miekg/dns v1.1.43
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -383,8 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
|||||||
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgKernelConfigurer struct {
|
type wgKernelConfigurer struct {
|
||||||
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fwmark := nbnet.NetbirdFwmark
|
fwmark := getFwmark()
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
PrivateKey: &key,
|
PrivateKey: &key,
|
||||||
ReplacePeers: true,
|
ReplacePeers: true,
|
||||||
|
|||||||
@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.NetbirdFwmark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -251,7 +251,7 @@ var (
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
|
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1473,7 +1473,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
|
|||||||
// if domain already has a primary account, add regular user
|
// if domain already has a primary account, add regular user
|
||||||
if domainAcc != nil {
|
if domainAcc != nil {
|
||||||
account = domainAcc
|
account = domainAcc
|
||||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1849,6 +1849,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
||||||
|
log.Debugf("validated peers has been invalidated for account %s", accountID)
|
||||||
updatedAccount, err := am.Store.GetAccount(accountID)
|
updatedAccount, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get account %s: %v", accountID, err)
|
log.Errorf("failed to get account %s: %v", accountID, err)
|
||||||
@@ -1861,9 +1862,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
|||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
allGroup := &nbgroup.Group{
|
allGroup := &nbgroup.Group{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: "All",
|
Name: "All",
|
||||||
Issued: nbgroup.GroupIssuedAPI,
|
Issued: nbgroup.GroupIssuedAPI,
|
||||||
|
AccountID: account.Id,
|
||||||
}
|
}
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||||
@@ -1907,7 +1909,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
|||||||
routes := make(map[string]*route.Route)
|
routes := make(map[string]*route.Route)
|
||||||
setupKeys := map[string]*SetupKey{}
|
setupKeys := map[string]*SetupKey{}
|
||||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||||
users[userID] = NewOwnerUser(userID)
|
users[userID] = NewOwnerUser(userID, accountID)
|
||||||
dnsSettings := DNSSettings{
|
dnsSettings := DNSSettings{
|
||||||
DisabledManagementGroups: make([]string, 0),
|
DisabledManagementGroups: make([]string, 0),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,133 +11,134 @@ type Code struct {
|
|||||||
Code string
|
Code string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Existing consts must not be changed, as this will break the compatibility with the existing data
|
||||||
const (
|
const (
|
||||||
// PeerAddedByUser indicates that a user added a new peer to the system
|
// PeerAddedByUser indicates that a user added a new peer to the system
|
||||||
PeerAddedByUser Activity = iota
|
PeerAddedByUser Activity = 0
|
||||||
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
|
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
|
||||||
PeerAddedWithSetupKey
|
PeerAddedWithSetupKey Activity = 1
|
||||||
// UserJoined indicates that a new user joined the account
|
// UserJoined indicates that a new user joined the account
|
||||||
UserJoined
|
UserJoined Activity = 2
|
||||||
// UserInvited indicates that a new user was invited to join the account
|
// UserInvited indicates that a new user was invited to join the account
|
||||||
UserInvited
|
UserInvited Activity = 3
|
||||||
// AccountCreated indicates that a new account has been created
|
// AccountCreated indicates that a new account has been created
|
||||||
AccountCreated
|
AccountCreated Activity = 4
|
||||||
// PeerRemovedByUser indicates that a user removed a peer from the system
|
// PeerRemovedByUser indicates that a user removed a peer from the system
|
||||||
PeerRemovedByUser
|
PeerRemovedByUser Activity = 5
|
||||||
// RuleAdded indicates that a user added a new rule
|
// RuleAdded indicates that a user added a new rule
|
||||||
RuleAdded
|
RuleAdded Activity = 6
|
||||||
// RuleUpdated indicates that a user updated a rule
|
// RuleUpdated indicates that a user updated a rule
|
||||||
RuleUpdated
|
RuleUpdated Activity = 7
|
||||||
// RuleRemoved indicates that a user removed a rule
|
// RuleRemoved indicates that a user removed a rule
|
||||||
RuleRemoved
|
RuleRemoved Activity = 8
|
||||||
// PolicyAdded indicates that a user added a new policy
|
// PolicyAdded indicates that a user added a new policy
|
||||||
PolicyAdded
|
PolicyAdded Activity = 9
|
||||||
// PolicyUpdated indicates that a user updated a policy
|
// PolicyUpdated indicates that a user updated a policy
|
||||||
PolicyUpdated
|
PolicyUpdated Activity = 10
|
||||||
// PolicyRemoved indicates that a user removed a policy
|
// PolicyRemoved indicates that a user removed a policy
|
||||||
PolicyRemoved
|
PolicyRemoved Activity = 11
|
||||||
// SetupKeyCreated indicates that a user created a new setup key
|
// SetupKeyCreated indicates that a user created a new setup key
|
||||||
SetupKeyCreated
|
SetupKeyCreated Activity = 12
|
||||||
// SetupKeyUpdated indicates that a user updated a setup key
|
// SetupKeyUpdated indicates that a user updated a setup key
|
||||||
SetupKeyUpdated
|
SetupKeyUpdated Activity = 13
|
||||||
// SetupKeyRevoked indicates that a user revoked a setup key
|
// SetupKeyRevoked indicates that a user revoked a setup key
|
||||||
SetupKeyRevoked
|
SetupKeyRevoked Activity = 14
|
||||||
// SetupKeyOverused indicates that setup key usage exhausted
|
// SetupKeyOverused indicates that setup key usage exhausted
|
||||||
SetupKeyOverused
|
SetupKeyOverused Activity = 15
|
||||||
// GroupCreated indicates that a user created a group
|
// GroupCreated indicates that a user created a group
|
||||||
GroupCreated
|
GroupCreated Activity = 16
|
||||||
// GroupUpdated indicates that a user updated a group
|
// GroupUpdated indicates that a user updated a group
|
||||||
GroupUpdated
|
GroupUpdated Activity = 17
|
||||||
// GroupAddedToPeer indicates that a user added group to a peer
|
// GroupAddedToPeer indicates that a user added group to a peer
|
||||||
GroupAddedToPeer
|
GroupAddedToPeer Activity = 18
|
||||||
// GroupRemovedFromPeer indicates that a user removed peer group
|
// GroupRemovedFromPeer indicates that a user removed peer group
|
||||||
GroupRemovedFromPeer
|
GroupRemovedFromPeer Activity = 19
|
||||||
// GroupAddedToUser indicates that a user added group to a user
|
// GroupAddedToUser indicates that a user added group to a user
|
||||||
GroupAddedToUser
|
GroupAddedToUser Activity = 20
|
||||||
// GroupRemovedFromUser indicates that a user removed a group from a user
|
// GroupRemovedFromUser indicates that a user removed a group from a user
|
||||||
GroupRemovedFromUser
|
GroupRemovedFromUser Activity = 21
|
||||||
// UserRoleUpdated indicates that a user changed the role of a user
|
// UserRoleUpdated indicates that a user changed the role of a user
|
||||||
UserRoleUpdated
|
UserRoleUpdated Activity = 22
|
||||||
// GroupAddedToSetupKey indicates that a user added group to a setup key
|
// GroupAddedToSetupKey indicates that a user added group to a setup key
|
||||||
GroupAddedToSetupKey
|
GroupAddedToSetupKey Activity = 23
|
||||||
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
|
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
|
||||||
GroupRemovedFromSetupKey
|
GroupRemovedFromSetupKey Activity = 24
|
||||||
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
|
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
|
||||||
GroupAddedToDisabledManagementGroups
|
GroupAddedToDisabledManagementGroups Activity = 25
|
||||||
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
|
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
|
||||||
GroupRemovedFromDisabledManagementGroups
|
GroupRemovedFromDisabledManagementGroups Activity = 26
|
||||||
// RouteCreated indicates that a user created a route
|
// RouteCreated indicates that a user created a route
|
||||||
RouteCreated
|
RouteCreated Activity = 27
|
||||||
// RouteRemoved indicates that a user deleted a route
|
// RouteRemoved indicates that a user deleted a route
|
||||||
RouteRemoved
|
RouteRemoved Activity = 28
|
||||||
// RouteUpdated indicates that a user updated a route
|
// RouteUpdated indicates that a user updated a route
|
||||||
RouteUpdated
|
RouteUpdated Activity = 29
|
||||||
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
|
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
|
||||||
PeerSSHEnabled
|
PeerSSHEnabled Activity = 30
|
||||||
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
|
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
|
||||||
PeerSSHDisabled
|
PeerSSHDisabled Activity = 31
|
||||||
// PeerRenamed indicates that a user renamed a peer
|
// PeerRenamed indicates that a user renamed a peer
|
||||||
PeerRenamed
|
PeerRenamed Activity = 32
|
||||||
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
|
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
|
||||||
PeerLoginExpirationEnabled
|
PeerLoginExpirationEnabled Activity = 33
|
||||||
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
|
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
|
||||||
PeerLoginExpirationDisabled
|
PeerLoginExpirationDisabled Activity = 34
|
||||||
// NameserverGroupCreated indicates that a user created a nameservers group
|
// NameserverGroupCreated indicates that a user created a nameservers group
|
||||||
NameserverGroupCreated
|
NameserverGroupCreated Activity = 35
|
||||||
// NameserverGroupDeleted indicates that a user deleted a nameservers group
|
// NameserverGroupDeleted indicates that a user deleted a nameservers group
|
||||||
NameserverGroupDeleted
|
NameserverGroupDeleted Activity = 36
|
||||||
// NameserverGroupUpdated indicates that a user updated a nameservers group
|
// NameserverGroupUpdated indicates that a user updated a nameservers group
|
||||||
NameserverGroupUpdated
|
NameserverGroupUpdated Activity = 37
|
||||||
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
|
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
|
||||||
AccountPeerLoginExpirationEnabled
|
AccountPeerLoginExpirationEnabled Activity = 38
|
||||||
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
|
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
|
||||||
AccountPeerLoginExpirationDisabled
|
AccountPeerLoginExpirationDisabled Activity = 39
|
||||||
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
|
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
|
||||||
AccountPeerLoginExpirationDurationUpdated
|
AccountPeerLoginExpirationDurationUpdated Activity = 40
|
||||||
// PersonalAccessTokenCreated indicates that a user created a personal access token
|
// PersonalAccessTokenCreated indicates that a user created a personal access token
|
||||||
PersonalAccessTokenCreated
|
PersonalAccessTokenCreated Activity = 41
|
||||||
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
||||||
PersonalAccessTokenDeleted
|
PersonalAccessTokenDeleted Activity = 42
|
||||||
// ServiceUserCreated indicates that a user created a service user
|
// ServiceUserCreated indicates that a user created a service user
|
||||||
ServiceUserCreated
|
ServiceUserCreated Activity = 43
|
||||||
// ServiceUserDeleted indicates that a user deleted a service user
|
// ServiceUserDeleted indicates that a user deleted a service user
|
||||||
ServiceUserDeleted
|
ServiceUserDeleted Activity = 44
|
||||||
// UserBlocked indicates that a user blocked another user
|
// UserBlocked indicates that a user blocked another user
|
||||||
UserBlocked
|
UserBlocked Activity = 45
|
||||||
// UserUnblocked indicates that a user unblocked another user
|
// UserUnblocked indicates that a user unblocked another user
|
||||||
UserUnblocked
|
UserUnblocked Activity = 46
|
||||||
// UserDeleted indicates that a user deleted another user
|
// UserDeleted indicates that a user deleted another user
|
||||||
UserDeleted
|
UserDeleted Activity = 47
|
||||||
// GroupDeleted indicates that a user deleted group
|
// GroupDeleted indicates that a user deleted group
|
||||||
GroupDeleted
|
GroupDeleted Activity = 48
|
||||||
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
||||||
UserLoggedInPeer
|
UserLoggedInPeer Activity = 49
|
||||||
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
|
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
|
||||||
PeerLoginExpired
|
PeerLoginExpired Activity = 50
|
||||||
// DashboardLogin indicates that the user logged in to the dashboard
|
// DashboardLogin indicates that the user logged in to the dashboard
|
||||||
DashboardLogin
|
DashboardLogin Activity = 51
|
||||||
// IntegrationCreated indicates that the user created an integration
|
// IntegrationCreated indicates that the user created an integration
|
||||||
IntegrationCreated
|
IntegrationCreated Activity = 52
|
||||||
// IntegrationUpdated indicates that the user updated an integration
|
// IntegrationUpdated indicates that the user updated an integration
|
||||||
IntegrationUpdated
|
IntegrationUpdated Activity = 53
|
||||||
// IntegrationDeleted indicates that the user deleted an integration
|
// IntegrationDeleted indicates that the user deleted an integration
|
||||||
IntegrationDeleted
|
IntegrationDeleted Activity = 54
|
||||||
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
|
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
|
||||||
AccountPeerApprovalEnabled
|
AccountPeerApprovalEnabled Activity = 55
|
||||||
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
|
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
|
||||||
AccountPeerApprovalDisabled
|
AccountPeerApprovalDisabled Activity = 56
|
||||||
// PeerApproved indicates that the peer has been approved
|
// PeerApproved indicates that the peer has been approved
|
||||||
PeerApproved
|
PeerApproved Activity = 57
|
||||||
// PeerApprovalRevoked indicates that the peer approval has been revoked
|
// PeerApprovalRevoked indicates that the peer approval has been revoked
|
||||||
PeerApprovalRevoked
|
PeerApprovalRevoked Activity = 58
|
||||||
// TransferredOwnerRole indicates that the user transferred the owner role of the account
|
// TransferredOwnerRole indicates that the user transferred the owner role of the account
|
||||||
TransferredOwnerRole
|
TransferredOwnerRole Activity = 59
|
||||||
// PostureCheckCreated indicates that the user created a posture check
|
// PostureCheckCreated indicates that the user created a posture check
|
||||||
PostureCheckCreated
|
PostureCheckCreated Activity = 60
|
||||||
// PostureCheckUpdated indicates that the user updated a posture check
|
// PostureCheckUpdated indicates that the user updated a posture check
|
||||||
PostureCheckUpdated
|
PostureCheckUpdated Activity = 61
|
||||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||||
PostureCheckDeleted
|
PostureCheckDeleted Activity = 62
|
||||||
)
|
)
|
||||||
|
|
||||||
var activityMap = map[Activity]Code{
|
var activityMap = map[Activity]Code{
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
|
|||||||
|
|
||||||
func TestAccounts_AccountsHandler(t *testing.T) {
|
func TestAccounts_AccountsHandler(t *testing.T) {
|
||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
|
|
||||||
sr := func(v string) *string { return &v }
|
sr := func(v string) *string { return &v }
|
||||||
br := func(v bool) *bool { return &v }
|
br := func(v bool) *bool { return &v }
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
|
|||||||
Id: testDNSSettingsAccountID,
|
Id: testDNSSettingsAccountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
testDNSSettingsUserID: server.NewAdminUser("test_user"),
|
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
|
||||||
},
|
},
|
||||||
DNSSettings: baseExistingDNSSettings,
|
DNSSettings: baseExistingDNSSettings,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
events := generateEvents(accountID, adminUser.Id)
|
events := generateEvents(accountID, adminUser.Id)
|
||||||
handler := initEventsTestData(accountID, adminUser, events...)
|
handler := initEventsTestData(accountID, adminUser, events...)
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
|||||||
return &GeolocationsHandler{
|
return &GeolocationsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
|
|||||||
Name: "Group",
|
Name: "Group",
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
p := initGroupTestData(adminUser, group)
|
p := initGroupTestData(adminUser, group)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
p := initGroupTestData(adminUser)
|
p := initGroupTestData(adminUser)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
p := initGroupTestData(adminUser)
|
p := initGroupTestData(adminUser)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
s "github.com/netbirdio/netbird/management/server"
|
s "github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
@@ -38,7 +39,7 @@ type emptyObject struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||||
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
|||||||
AuthCfg: authCfg,
|
AuthCfg: authCfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
|
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
|
||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
|
|||||||
Id: testNSGroupAccountID,
|
Id: testNSGroupAccountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
"test_user": server.NewAdminUser("test_user", "account_id"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
return "netbird.selfhosted"
|
return "netbird.selfhosted"
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
|||||||
return accountPostureChecks, nil
|
return accountPostureChecks, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
"test_user": server.NewAdminUser("test_user", "account_id"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
defaultSetupKey := server.GenerateDefaultSetupKey()
|
defaultSetupKey := server.GenerateDefaultSetupKey()
|
||||||
defaultSetupKey.Id = existingSetupKeyID
|
defaultSetupKey.Id = existingSetupKeyID
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
|
|
||||||
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
||||||
server.SetupKeyUnlimitedUsage, true)
|
server.SetupKeyUnlimitedUsage, true)
|
||||||
|
|||||||
@@ -551,8 +551,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
|
|||||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||||
}
|
}
|
||||||
|
|
||||||
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||||
if requiresApproval {
|
if peerNotValid {
|
||||||
emptyMap := &NetworkMap{
|
emptyMap := &NetworkMap{
|
||||||
Network: account.Network.Copy(),
|
Network: account.Network.Copy(),
|
||||||
}
|
}
|
||||||
@@ -563,11 +563,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
|
|||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
}
|
}
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
validPeersMap, err := am.GetValidatedPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
|
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginPeer logs in or registers a peer.
|
// LoginPeer logs in or registers a peer.
|
||||||
|
|||||||
@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
select {
|
select {
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.Debugf("scheduled job %s was canceled, stop timer", ID)
|
log.Tracef("scheduled job %s was canceled, stop timer", ID)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
log.Debugf("time to do a scheduled job %s", ID)
|
log.Tracef("time to do a scheduled job %s", ID)
|
||||||
}
|
}
|
||||||
runIn, reschedule := job()
|
runIn, reschedule := job()
|
||||||
if !reschedule {
|
if !reschedule {
|
||||||
wm.mu.Lock()
|
wm.mu.Lock()
|
||||||
defer wm.mu.Unlock()
|
defer wm.mu.Unlock()
|
||||||
delete(wm.jobs, ID)
|
delete(wm.jobs, ID)
|
||||||
log.Debugf("job %s is not scheduled to run again", ID)
|
log.Tracef("job %s is not scheduled to run again", ID)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
|
|||||||
ticker.Reset(runIn)
|
ticker.Reset(runIn)
|
||||||
}
|
}
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.Debugf("job %s was canceled, stopping timer", ID)
|
log.Tracef("job %s was canceled, stopping timer", ID)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
|
|||||||
return unlock
|
return unlock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
|
||||||
|
// Get the reflect.Value of the records slice
|
||||||
|
v := reflect.ValueOf(records)
|
||||||
|
if v.Kind() != reflect.Slice {
|
||||||
|
return fmt.Errorf("provided input is not a slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert records in batches
|
||||||
|
for i := 0; i < v.Len(); i += batchSize {
|
||||||
|
end := i + batchSize
|
||||||
|
if end > v.Len() {
|
||||||
|
end = v.Len()
|
||||||
|
}
|
||||||
|
// Use reflect.Slice to get a slice of the records for the current batch
|
||||||
|
batch := v.Slice(i, end).Interface()
|
||||||
|
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqliteStore) SaveAccount(account *Account) error {
|
func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
for _, key := range account.SetupKeys {
|
// operate over a fresh copy as we will modify its fields
|
||||||
account.SetupKeysG = append(account.SetupKeysG, *key)
|
accCopy := account.Copy()
|
||||||
|
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
|
||||||
|
for _, key := range accCopy.SetupKeys {
|
||||||
|
//we need an explicit reference to the account for gorm
|
||||||
|
key.AccountID = accCopy.Id
|
||||||
|
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, peer := range account.Peers {
|
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
|
||||||
|
for id, peer := range accCopy.Peers {
|
||||||
peer.ID = id
|
peer.ID = id
|
||||||
account.PeersG = append(account.PeersG, *peer)
|
//we need an explicit reference to the account for gorm
|
||||||
|
peer.AccountID = accCopy.Id
|
||||||
|
accCopy.PeersG = append(accCopy.PeersG, *peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, user := range account.Users {
|
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
|
||||||
|
for id, user := range accCopy.Users {
|
||||||
user.Id = id
|
user.Id = id
|
||||||
|
//we need an explicit reference to the account for gorm
|
||||||
|
user.AccountID = accCopy.Id
|
||||||
|
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
|
||||||
for id, pat := range user.PATs {
|
for id, pat := range user.PATs {
|
||||||
pat.ID = id
|
pat.ID = id
|
||||||
user.PATsG = append(user.PATsG, *pat)
|
user.PATsG = append(user.PATsG, *pat)
|
||||||
}
|
}
|
||||||
account.UsersG = append(account.UsersG, *user)
|
accCopy.UsersG = append(accCopy.UsersG, *user)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, group := range account.Groups {
|
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
|
||||||
|
for id, group := range accCopy.Groups {
|
||||||
group.ID = id
|
group.ID = id
|
||||||
account.GroupsG = append(account.GroupsG, *group)
|
//we need an explicit reference to the account for gorm
|
||||||
|
group.AccountID = accCopy.Id
|
||||||
|
accCopy.GroupsG = append(accCopy.GroupsG, *group)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, route := range account.Routes {
|
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
|
||||||
|
for id, route := range accCopy.Routes {
|
||||||
route.ID = id
|
route.ID = id
|
||||||
account.RoutesG = append(account.RoutesG, *route)
|
//we need an explicit reference to the account for gorm
|
||||||
|
route.AccountID = accCopy.Id
|
||||||
|
accCopy.RoutesG = append(accCopy.RoutesG, *route)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, ns := range account.NameServerGroups {
|
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
|
||||||
|
for id, ns := range accCopy.NameServerGroups {
|
||||||
ns.ID = id
|
ns.ID = id
|
||||||
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
|
//we need an explicit reference to the account for gorm
|
||||||
|
ns.AccountID = accCopy.Id
|
||||||
|
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tx.Select(clause.Associations).Delete(account)
|
result = tx.Select(clause.Associations).Delete(accCopy)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tx.
|
result = tx.
|
||||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
|
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||||
|
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
|
||||||
|
Create(accCopy)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
const batchSize = 500
|
||||||
|
err := batchInsert(accCopy.PeersG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.UsersG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.GroupsG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.RoutesG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
|
||||||
})
|
})
|
||||||
|
|
||||||
took := time.Since(start)
|
took := time.Since(start)
|
||||||
if s.metrics != nil {
|
if s.metrics != nil {
|
||||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||||
}
|
}
|
||||||
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
|
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
|
|||||||
func (s *SqliteStore) DeleteAccount(account *Account) error {
|
func (s *SqliteStore) DeleteAccount(account *Account) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
|
account.UsersG = make([]User, 0, len(account.Users))
|
||||||
|
for id, user := range account.Users {
|
||||||
|
user.Id = id
|
||||||
|
//we need an explicit reference to an account as it is missing for some reason
|
||||||
|
user.AccountID = account.Id
|
||||||
|
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
|
||||||
|
for id, pat := range user.PATs {
|
||||||
|
pat.ID = id
|
||||||
|
user.PATsG = append(user.PATsG, *pat)
|
||||||
|
}
|
||||||
|
account.UsersG = append(account.UsersG, *user)
|
||||||
|
}
|
||||||
|
|
||||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
|
|||||||
@@ -2,7 +2,12 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
route2 "github.com/netbirdio/netbird/route"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
|
|||||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStore(t)
|
||||||
|
|
||||||
|
account := newAccountWithId("account_id", "testuser", "")
|
||||||
|
groupALL, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
setupKey := GenerateDefaultSetupKey()
|
||||||
|
account.SetupKeys[setupKey.Key] = setupKey
|
||||||
|
const numPerAccount = 2000
|
||||||
|
for n := 0; n < numPerAccount; n++ {
|
||||||
|
netIP := randomIPv4()
|
||||||
|
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
|
||||||
|
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: peerID,
|
||||||
|
Key: peerID,
|
||||||
|
SetupKey: "",
|
||||||
|
IP: netIP,
|
||||||
|
Name: peerID,
|
||||||
|
DNSLabel: peerID,
|
||||||
|
UserID: userID,
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
}
|
||||||
|
account.Peers[peerID] = peer
|
||||||
|
group, _ := account.GetGroupAll()
|
||||||
|
group.Peers = append(group.Peers, peerID)
|
||||||
|
user := &User{
|
||||||
|
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
|
||||||
|
AccountID: account.Id,
|
||||||
|
}
|
||||||
|
account.Users[user.Id] = user
|
||||||
|
route := &route2.Route{
|
||||||
|
ID: fmt.Sprintf("network-id-%d", n),
|
||||||
|
Description: "base route",
|
||||||
|
NetID: fmt.Sprintf("network-id-%d", n),
|
||||||
|
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
||||||
|
NetworkType: route2.IPv4Network,
|
||||||
|
Metric: 9999,
|
||||||
|
Masquerade: false,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{groupALL.ID},
|
||||||
|
}
|
||||||
|
account.Routes[route.ID] = route
|
||||||
|
|
||||||
|
group = &nbgroup.Group{
|
||||||
|
ID: fmt.Sprintf("group-id-%d", n),
|
||||||
|
AccountID: account.Id,
|
||||||
|
Name: fmt.Sprintf("group-id-%d", n),
|
||||||
|
Issued: "api",
|
||||||
|
Peers: nil,
|
||||||
|
}
|
||||||
|
account.Groups[group.ID] = group
|
||||||
|
|
||||||
|
nameserver := &nbdns.NameServerGroup{
|
||||||
|
ID: fmt.Sprintf("nameserver-id-%d", n),
|
||||||
|
AccountID: account.Id,
|
||||||
|
Name: fmt.Sprintf("nameserver-id-%d", n),
|
||||||
|
Description: "",
|
||||||
|
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
|
||||||
|
Groups: []string{group.ID},
|
||||||
|
Primary: false,
|
||||||
|
Domains: nil,
|
||||||
|
Enabled: false,
|
||||||
|
SearchDomainsEnabled: false,
|
||||||
|
}
|
||||||
|
account.NameServerGroups[nameserver.ID] = nameserver
|
||||||
|
|
||||||
|
setupKey := GenerateDefaultSetupKey()
|
||||||
|
account.SetupKeys[setupKey.Key] = setupKey
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.SaveAccount(account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if len(store.GetAllAccounts()) != 1 {
|
||||||
|
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := store.GetAccount(account.Id)
|
||||||
|
if a == nil {
|
||||||
|
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Policies) != 1 {
|
||||||
|
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Policies[0].Rules) != 1 {
|
||||||
|
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Peers) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.Peers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Users) != numPerAccount+1 {
|
||||||
|
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount+1, len(a.Users))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Routes) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.Routes))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.NameServerGroups) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.NameServerGroups))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.NameServerGroups) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.NameServerGroups))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
|
||||||
|
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount+1, len(a.SetupKeys))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqlite_SaveAccount(t *testing.T) {
|
func TestSqlite_SaveAccount(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
|||||||
Name: "peer name",
|
Name: "peer name",
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||||
}
|
}
|
||||||
|
admin := account.Users["testuser"]
|
||||||
|
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||||
|
ID: "testtoken",
|
||||||
|
Name: "test token",
|
||||||
|
HashedToken: "hashed token",
|
||||||
|
}}
|
||||||
|
|
||||||
err := store.SaveAccount(account)
|
err := store.SaveAccount(account)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
|||||||
store := newSqliteStore(t)
|
store := newSqliteStore(t)
|
||||||
|
|
||||||
testUserID := "testuser"
|
testUserID := "testuser"
|
||||||
user := NewAdminUser(testUserID)
|
user := NewAdminUser(testUserID, "account_id")
|
||||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||||
ID: "testtoken",
|
ID: "testtoken",
|
||||||
Name: "test token",
|
Name: "test token",
|
||||||
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
|
|||||||
|
|
||||||
return store.SaveAccount(account)
|
return store.SaveAccount(account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randomIPv4() net.IP {
|
||||||
|
rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
b := make([]byte, 4)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = byte(rand.Intn(256))
|
||||||
|
}
|
||||||
|
return net.IP(b)
|
||||||
|
}
|
||||||
|
|||||||
@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewUser creates a new user
|
// NewUser creates a new user
|
||||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
|
||||||
|
accountID string) *User {
|
||||||
return &User{
|
return &User{
|
||||||
Id: id,
|
Id: ID,
|
||||||
|
AccountID: accountID,
|
||||||
Role: role,
|
Role: role,
|
||||||
IsServiceUser: isServiceUser,
|
IsServiceUser: isServiceUser,
|
||||||
NonDeletable: nonDeletable,
|
NonDeletable: nonDeletable,
|
||||||
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewRegularUser creates a new user with role UserRoleUser
|
// NewRegularUser creates a new user with role UserRoleUser
|
||||||
func NewRegularUser(id string) *User {
|
func NewRegularUser(ID, accountID string) *User {
|
||||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
|
||||||
|
accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||||
func NewAdminUser(id string) *User {
|
func NewAdminUser(ID, accountID string) *User {
|
||||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
|
||||||
|
accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||||
func NewOwnerUser(id string) *User {
|
func NewOwnerUser(ID, accountID string) *User {
|
||||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
|
||||||
|
accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createServiceUser creates a new service user under the given account.
|
// createServiceUser creates a new service user under the given account.
|
||||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
|
||||||
|
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
|
|||||||
}
|
}
|
||||||
|
|
||||||
newUserID := uuid.New().String()
|
newUserID := uuid.New().String()
|
||||||
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
|
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
|
||||||
log.Debugf("New User: %v", newUser)
|
log.Debugf("New User: %v", newUser)
|
||||||
account.Users[newUserID] = newUser
|
account.Users[newUserID] = newUser
|
||||||
|
|
||||||
|
|||||||
@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
|||||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
|
||||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
|
||||||
|
|
||||||
err := store.SaveAccount(account)
|
err := store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
|||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
|
||||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||||
delete(account.Users, mockUserID)
|
delete(account.Users, mockUserID)
|
||||||
|
|
||||||
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
|||||||
|
|
||||||
func TestUser_IsAdmin(t *testing.T) {
|
func TestUser_IsAdmin(t *testing.T) {
|
||||||
|
|
||||||
user := NewAdminUser(mockUserID)
|
user := NewAdminUser(mockUserID, mockAccountID)
|
||||||
assert.True(t, user.HasAdminPower())
|
assert.True(t, user.HasAdminPower())
|
||||||
|
|
||||||
user = NewRegularUser(mockUserID)
|
user = NewRegularUser(mockUserID, mockAccountID)
|
||||||
assert.False(t, user.HasAdminPower())
|
assert.False(t, user.HasAdminPower())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create other users
|
// create other users
|
||||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
|
||||||
account.Users[adminUserID] = NewAdminUser(adminUserID)
|
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
|
||||||
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
||||||
err = manager.Store.SaveAccount(account)
|
err = manager.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -12,6 +14,20 @@ import (
|
|||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
func WithCustomDialer() grpc.DialOption {
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||||
|
if currentUser.Uid != "0" {
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
return dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to dial: %s", err)
|
log.Errorf("Failed to dial: %s", err)
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ func RemoveDialerHooks() {
|
|||||||
|
|
||||||
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return d.Dialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
var resolver *net.Resolver
|
var resolver *net.Resolver
|
||||||
if d.Resolver != nil {
|
if d.Resolver != nil {
|
||||||
resolver = d.Resolver
|
resolver = d.Resolver
|
||||||
@@ -123,6 +127,10 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
dialer := NewDialer()
|
dialer := NewDialer()
|
||||||
dialer.LocalAddr = laddr
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
@@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
dialer := NewDialer()
|
dialer := NewDialer()
|
||||||
dialer.LocalAddr = laddr
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,6 +53,10 @@ func RemoveListenerHooks() {
|
|||||||
// ListenPacket listens on the network address and returns a PacketConn
|
// ListenPacket listens on the network address and returns a PacketConn
|
||||||
// which includes support for write hooks.
|
// which includes support for write hooks.
|
||||||
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("listen packet: %w", err)
|
return nil, fmt.Errorf("listen packet: %w", err)
|
||||||
@@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
|
|||||||
|
|
||||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||||
// which includes support for write and close hooks.
|
// which includes support for write and close hooks.
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
|
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.ListenUDP(network, laddr)
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||||
NetbirdFwmark = 0x1BD00
|
NetbirdFwmark = 0x1BD00
|
||||||
|
|
||||||
|
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionID provides a globally unique identifier for network connections.
|
// ConnectionID provides a globally unique identifier for network connections.
|
||||||
@@ -15,3 +21,7 @@ type ConnectionID string
|
|||||||
func GenerateConnID() ConnectionID {
|
func GenerateConnID() ConnectionID {
|
||||||
return ConnectionID(uuid.NewString())
|
return ConnectionID(uuid.NewString())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CustomRoutingDisabled() bool {
|
||||||
|
return os.Getenv(envDisableCustomRouting) == "true"
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
|||||||
var setErr error
|
var setErr error
|
||||||
|
|
||||||
err := conn.Control(func(fd uintptr) {
|
err := conn.Control(func(fd uintptr) {
|
||||||
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
setErr = SetSocketOpt(int(fd))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("control: %w", err)
|
return fmt.Errorf("control: %w", err)
|
||||||
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetSocketOpt(fd int) error {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user