mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Compare commits
55 Commits
v0.26.5
...
feature/op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
515ce9e3af | ||
|
|
89383b7f01 | ||
|
|
db34162733 | ||
|
|
bd761e2177 | ||
|
|
4e1b95a4c6 | ||
|
|
05993af7bf | ||
|
|
9d1cb00570 | ||
|
|
543731df45 | ||
|
|
e6628ec231 | ||
|
|
41d4dd2aff | ||
|
|
30bed57711 | ||
|
|
6960b68322 | ||
|
|
3b3aa18148 | ||
|
|
93045f3e3a | ||
|
|
fd3c1dea8e | ||
|
|
48aff7a26e | ||
|
|
83dfe8e3a3 | ||
|
|
38e10af2d9 | ||
|
|
99854a126a | ||
|
|
a75f982fcd | ||
|
|
e7a6483912 | ||
|
|
30ede299b8 | ||
|
|
e3b76448f3 | ||
|
|
e0de86d6c9 | ||
|
|
5204d07811 | ||
|
|
5ea24ba56e | ||
|
|
d30cf8706a | ||
|
|
15a2feb723 | ||
|
|
91b2f9fc51 | ||
|
|
76702c8a09 | ||
|
|
061f673a4f | ||
|
|
9505805313 | ||
|
|
704c67dec8 | ||
|
|
3ed2f08f3c | ||
|
|
4c83408f27 | ||
|
|
90bd39c740 | ||
|
|
dd0cf41147 | ||
|
|
22b2caffc6 | ||
|
|
c1f66d1354 | ||
|
|
ac0fe6025b | ||
|
|
c28657710a | ||
|
|
3875c29f6b | ||
|
|
9f32ccd453 | ||
|
|
1d1d057e7d | ||
|
|
3461b1bb90 | ||
|
|
3d2a2377c6 | ||
|
|
25f5f26527 | ||
|
|
bb0d5c5baf | ||
|
|
7938295190 | ||
|
|
9af532fe71 | ||
|
|
23a1473797 | ||
|
|
9c2dc05df1 | ||
|
|
40d56e5d29 | ||
|
|
fd23d0c28f | ||
|
|
4fff93a1f2 |
3
.github/workflows/golang-test-darwin.yml
vendored
3
.github/workflows/golang-test-darwin.yml
vendored
@@ -32,6 +32,9 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
macos-go-
|
macos-go-
|
||||||
|
|
||||||
|
- name: Install libpcap
|
||||||
|
run: brew install libpcap
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
|||||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -33,6 +33,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
- name: Check for duplicate constants
|
||||||
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
run: |
|
||||||
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -40,11 +40,12 @@
|
|||||||
|
|
||||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|
||||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||||
|
|
||||||
### Open-Source Network Security in a Single Platform
|
### Open-Source Network Security in a Single Platform
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
@@ -76,7 +77,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
|||||||
- **Public domain** name pointing to the VM.
|
- **Public domain** name pointing to the VM.
|
||||||
|
|
||||||
**Software requirements:**
|
**Software requirements:**
|
||||||
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||||
- [curl](https://curl.se/) installed.
|
- [curl](https://curl.se/) installed.
|
||||||
@@ -93,9 +94,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
|||||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||||
- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||||
|
|
||||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||||
|
|
||||||
@@ -119,7 +120,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
|
|||||||

|

|
||||||
|
|
||||||
### Testimonials
|
### Testimonials
|
||||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||||
|
|
||||||
### Legal
|
### Legal
|
||||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||||
|
|||||||
@@ -64,6 +64,10 @@ var installCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
svcConfig.Option["OnFailure"] = "restart"
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -93,7 +95,13 @@ func runClient(
|
|||||||
relayProbe *Probe,
|
relayProbe *Probe,
|
||||||
wgProbe *Probe,
|
wgProbe *Probe,
|
||||||
) error {
|
) error {
|
||||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||||
|
|||||||
@@ -93,6 +93,10 @@ type Engine struct {
|
|||||||
mgmClient mgm.Client
|
mgmClient mgm.Client
|
||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerConns map[string]*peer.Conn
|
peerConns map[string]*peer.Conn
|
||||||
|
|
||||||
|
beforePeerHook peer.BeforeAddPeerHookFunc
|
||||||
|
afterPeerHook peer.AfterRemovePeerHookFunc
|
||||||
|
|
||||||
// rpManager is a Rosenpass manager
|
// rpManager is a Rosenpass manager
|
||||||
rpManager *rosenpass.Manager
|
rpManager *rosenpass.Manager
|
||||||
|
|
||||||
@@ -260,10 +264,14 @@ func (e *Engine) Start() error {
|
|||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||||
if err := e.routeManager.Init(); err != nil {
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
e.close()
|
if err != nil {
|
||||||
return fmt.Errorf("init route manager: %w", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
|
} else {
|
||||||
|
e.beforePeerHook = beforePeerHook
|
||||||
|
e.afterPeerHook = afterPeerHook
|
||||||
}
|
}
|
||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
err = e.wgInterfaceCreate()
|
err = e.wgInterfaceCreate()
|
||||||
@@ -786,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
|||||||
FQDN: offlinePeer.GetFqdn(),
|
FQDN: offlinePeer.GetFqdn(),
|
||||||
ConnStatus: peer.StatusDisconnected,
|
ConnStatus: peer.StatusDisconnected,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||||
@@ -809,10 +818,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
if _, ok := e.peerConns[peerKey]; !ok {
|
if _, ok := e.peerConns[peerKey]; !ok {
|
||||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
e.peerConns[peerKey] = conn
|
||||||
|
|
||||||
|
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||||
|
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||||
|
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||||
|
}
|
||||||
|
|
||||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
@@ -1106,6 +1120,10 @@ func (e *Engine) close() {
|
|||||||
e.dnsServer.Stop()
|
e.dnsServer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.routeManager != nil {
|
||||||
|
e.routeManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
if err := e.wgInterface.Close(); err != nil {
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
@@ -1120,10 +1138,6 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.routeManager != nil {
|
|
||||||
e.routeManager.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
err := e.firewall.Reset()
|
err := e.firewall.Reset()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -20,12 +20,15 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface/bind"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
iceKeepAliveDefault = 4 * time.Second
|
iceKeepAliveDefault = 4 * time.Second
|
||||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||||
|
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||||
|
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||||
|
|
||||||
defaultWgKeepAlive = 25 * time.Second
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
)
|
)
|
||||||
@@ -98,6 +101,9 @@ type IceCredentials struct {
|
|||||||
Pwd string
|
Pwd string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
|
||||||
|
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -136,6 +142,10 @@ type Conn struct {
|
|||||||
|
|
||||||
remoteEndpoint *net.UDPAddr
|
remoteEndpoint *net.UDPAddr
|
||||||
remoteConn *ice.Conn
|
remoteConn *ice.Conn
|
||||||
|
|
||||||
|
connID nbnet.ConnectionID
|
||||||
|
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
||||||
|
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// meta holds meta information about a connection
|
// meta holds meta information about a connection
|
||||||
@@ -196,6 +206,7 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
|
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
|
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||||
|
|
||||||
agentConfig := &ice.AgentConfig{
|
agentConfig := &ice.AgentConfig{
|
||||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||||
@@ -210,6 +221,7 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
Net: transportNet,
|
Net: transportNet,
|
||||||
DisconnectedTimeout: &iceDisconnectedTimeout,
|
DisconnectedTimeout: &iceDisconnectedTimeout,
|
||||||
KeepaliveInterval: &iceKeepAlive,
|
KeepaliveInterval: &iceKeepAlive,
|
||||||
|
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.config.DisableIPv6Discovery {
|
if conn.config.DisableIPv6Discovery {
|
||||||
@@ -217,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.agent, err = ice.NewAgent(agentConfig)
|
conn.agent, err = ice.NewAgent(agentConfig)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -273,6 +284,7 @@ func (conn *Conn) Open() error {
|
|||||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -332,6 +344,7 @@ func (conn *Conn) Open() error {
|
|||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -389,6 +402,14 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
|||||||
return candidate.Type() == ice.CandidateTypeRelay
|
return candidate.Type() == ice.CandidateTypeRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
|
||||||
|
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
|
||||||
|
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
@@ -415,6 +436,14 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||||
conn.remoteEndpoint = endpointUdpAddr
|
conn.remoteEndpoint = endpointUdpAddr
|
||||||
|
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||||
|
|
||||||
|
conn.connID = nbnet.GenerateConnID()
|
||||||
|
for _, hook := range conn.beforeAddPeerHooks {
|
||||||
|
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
|
||||||
|
log.Errorf("Before add peer hook failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -437,9 +466,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
LocalIceCandidateType: pair.Local.Type().String(),
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
|
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||||
Direct: !isRelayCandidate(pair.Local),
|
Direct: !isRelayCandidate(pair.Local),
|
||||||
RosenpassEnabled: rosenpassEnabled,
|
RosenpassEnabled: rosenpassEnabled,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
@@ -506,6 +536,15 @@ func (conn *Conn) cleanup() error {
|
|||||||
// todo: is it problem if we try to remove a peer what is never existed?
|
// todo: is it problem if we try to remove a peer what is never existed?
|
||||||
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
|
||||||
|
if conn.connID != "" {
|
||||||
|
for _, hook := range conn.afterRemovePeerHooks {
|
||||||
|
if err := hook(conn.connID); err != nil {
|
||||||
|
log.Errorf("After remove peer hook failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.connID = ""
|
||||||
|
|
||||||
if conn.notifyDisconnected != nil {
|
if conn.notifyDisconnected != nil {
|
||||||
conn.notifyDisconnected()
|
conn.notifyDisconnected()
|
||||||
conn.notifyDisconnected = nil
|
conn.notifyDisconnected = nil
|
||||||
@@ -521,6 +560,7 @@ func (conn *Conn) cleanup() error {
|
|||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
||||||
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
||||||
|
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
|
||||||
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
|
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ func iceKeepAlive() time.Duration {
|
|||||||
return iceKeepAliveDefault
|
return iceKeepAliveDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
||||||
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
|
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
|
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
|
||||||
@@ -37,7 +38,7 @@ func iceDisconnectedTimeout() time.Duration {
|
|||||||
return iceDisconnectedTimeoutDefault
|
return iceDisconnectedTimeoutDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
||||||
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
|
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
|
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
|
||||||
@@ -47,6 +48,22 @@ func iceDisconnectedTimeout() time.Duration {
|
|||||||
return time.Duration(disconnectedTimeoutSec) * time.Second
|
return time.Duration(disconnectedTimeoutSec) * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func iceRelayAcceptanceMinWait() time.Duration {
|
||||||
|
iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec)
|
||||||
|
if iceRelayAcceptanceMinWaitEnv == "" {
|
||||||
|
return iceRelayAcceptanceMinWaitDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
|
||||||
|
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
|
||||||
|
return iceRelayAcceptanceMinWaitDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Duration(disconnectedTimeoutSec) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
func hasICEForceRelayConn() bool {
|
func hasICEForceRelayConn() bool {
|
||||||
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
|
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
|
||||||
return strings.ToLower(disconnectedTimeoutEnv) == "true"
|
return strings.ToLower(disconnectedTimeoutEnv) == "true"
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
// State contains the latest state of a peer
|
// State contains the latest state of a peer
|
||||||
type State struct {
|
type State struct {
|
||||||
|
Mux *sync.RWMutex
|
||||||
IP string
|
IP string
|
||||||
PubKey string
|
PubKey string
|
||||||
FQDN string
|
FQDN string
|
||||||
@@ -30,7 +31,38 @@ type State struct {
|
|||||||
BytesRx int64
|
BytesRx int64
|
||||||
Latency time.Duration
|
Latency time.Duration
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
Routes map[string]struct{}
|
routes map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRoute add a single route to routes map
|
||||||
|
func (s *State) AddRoute(network string) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
if s.routes == nil {
|
||||||
|
s.routes = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
s.routes[network] = struct{}{}
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoutes set state routes
|
||||||
|
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
s.routes = routes
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRoute removes a route from the network amp
|
||||||
|
func (s *State) DeleteRoute(network string) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
delete(s.routes, network)
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutes return routes map
|
||||||
|
func (s *State) GetRoutes() map[string]struct{} {
|
||||||
|
s.Mux.RLock()
|
||||||
|
defer s.Mux.RUnlock()
|
||||||
|
return s.routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
// LocalPeerState contains the latest state of the local peer
|
||||||
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
|||||||
PubKey: peerPubKey,
|
PubKey: peerPubKey,
|
||||||
ConnStatus: StatusDisconnected,
|
ConnStatus: StatusDisconnected,
|
||||||
FQDN: fqdn,
|
FQDN: fqdn,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
d.peerListChangedForNotification = true
|
d.peerListChangedForNotification = true
|
||||||
return nil
|
return nil
|
||||||
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
peerState.IP = receivedState.IP
|
peerState.IP = receivedState.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
if receivedState.Routes != nil {
|
if receivedState.GetRoutes() != nil {
|
||||||
peerState.Routes = receivedState.Routes
|
peerState.SetRoutes(receivedState.GetRoutes())
|
||||||
}
|
}
|
||||||
|
|
||||||
skipNotification := shouldSkipNotify(receivedState, peerState)
|
skipNotification := shouldSkipNotify(receivedState, peerState)
|
||||||
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
|
|||||||
s, ok := gstatus.FromError(d.managementError)
|
s, ok := gstatus.FromError(d.managementError)
|
||||||
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
return true
|
return true
|
||||||
|
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package peer
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -18,6 +19,7 @@ type routerPeerStatus struct {
|
|||||||
connected bool
|
connected bool
|
||||||
relayed bool
|
relayed bool
|
||||||
direct bool
|
direct bool
|
||||||
|
latency time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type routesUpdate struct {
|
type routesUpdate struct {
|
||||||
@@ -68,6 +70,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
||||||
relayed: peerStatus.Relayed,
|
relayed: peerStatus.Relayed,
|
||||||
direct: peerStatus.Direct,
|
direct: peerStatus.Direct,
|
||||||
|
latency: peerStatus.Latency,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return routePeerStatuses
|
return routePeerStatuses
|
||||||
@@ -83,11 +86,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
// * Non-relayed: Routes without relays are preferred.
|
// * Non-relayed: Routes without relays are preferred.
|
||||||
// * Direct connections: Routes with direct peer connections are favored.
|
// * Direct connections: Routes with direct peer connections are favored.
|
||||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||||
|
// * Latency: Routes with lower latency are prioritized.
|
||||||
//
|
//
|
||||||
// It returns the ID of the selected optimal route.
|
// It returns the ID of the selected optimal route.
|
||||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||||
chosen := ""
|
chosen := ""
|
||||||
chosenScore := 0
|
chosenScore := float64(0)
|
||||||
|
currScore := float64(0)
|
||||||
|
|
||||||
currID := ""
|
currID := ""
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
@@ -95,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
tempScore := 0
|
tempScore := float64(0)
|
||||||
peerStatus, found := routePeerStatuses[r.ID]
|
peerStatus, found := routePeerStatuses[r.ID]
|
||||||
if !found || !peerStatus.connected {
|
if !found || !peerStatus.connected {
|
||||||
continue
|
continue
|
||||||
@@ -103,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
|
|
||||||
if r.Metric < route.MaxMetric {
|
if r.Metric < route.MaxMetric {
|
||||||
metricDiff := route.MaxMetric - r.Metric
|
metricDiff := route.MaxMetric - r.Metric
|
||||||
tempScore = metricDiff * 10
|
tempScore = float64(metricDiff) * 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||||
|
latency := time.Second
|
||||||
|
if peerStatus.latency != 0 {
|
||||||
|
latency = peerStatus.latency
|
||||||
|
} else {
|
||||||
|
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||||
|
}
|
||||||
|
tempScore += 1 - latency.Seconds()
|
||||||
|
|
||||||
if !peerStatus.relayed {
|
if !peerStatus.relayed {
|
||||||
tempScore++
|
tempScore++
|
||||||
}
|
}
|
||||||
@@ -114,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
tempScore++
|
tempScore++
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
|
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
|
||||||
chosen = r.ID
|
chosen = r.ID
|
||||||
chosenScore = tempScore
|
chosenScore = tempScore
|
||||||
}
|
}
|
||||||
@@ -123,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
chosen = r.ID
|
chosen = r.ID
|
||||||
chosenScore = tempScore
|
chosenScore = tempScore
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.ID == currID {
|
||||||
|
currScore = tempScore
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if chosen == "" {
|
switch {
|
||||||
|
case chosen == "":
|
||||||
var peers []string
|
var peers []string
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
peers = append(peers, r.Peer)
|
peers = append(peers, r.Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||||
|
case chosen != currID:
|
||||||
} else if chosen != currID {
|
if currScore != 0 && currScore < chosenScore+0.1 {
|
||||||
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
|
return currID
|
||||||
|
} else {
|
||||||
|
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return chosen
|
return chosen
|
||||||
@@ -174,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
|||||||
return fmt.Errorf("get peer state: %v", err)
|
return fmt.Errorf("get peer state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(state.Routes, c.network.String())
|
state.DeleteRoute(c.network.String())
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
}
|
}
|
||||||
@@ -193,7 +215,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
|||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
|
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
||||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,7 +256,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// otherwise add the route to the system
|
// otherwise add the route to the system
|
||||||
if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
|
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
||||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||||
}
|
}
|
||||||
@@ -246,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get peer state: %v", err)
|
log.Errorf("Failed to get peer state: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if state.Routes == nil {
|
state.AddRoute(c.network.String())
|
||||||
state.Routes = map[string]struct{}{}
|
|
||||||
}
|
|
||||||
state.Routes[c.network.String()] = struct{}{}
|
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package routemanager
|
|||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
statuses map[string]routerPeerStatus
|
statuses map[string]routerPeerStatus
|
||||||
expectedRouteID string
|
expectedRouteID string
|
||||||
currentRoute *route.Route
|
currentRoute string
|
||||||
existingRoutes map[string]*route.Route
|
existingRoutes map[string]*route.Route
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "",
|
expectedRouteID: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer2",
|
Peer: "peer2",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer2",
|
Peer: "peer2",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer2",
|
Peer: "peer2",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "multiple connected peers with different latencies",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
latency: 300 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "",
|
||||||
|
expectedRouteID: "route2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should ignore routes with latency 0",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "",
|
||||||
|
expectedRouteID: "route2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "current route with similar score and similar but slightly worse latency should not change",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 12 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "route1",
|
||||||
|
expectedRouteID: "route1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "current chosen route doesn't exist anymore",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 20 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "routeDoesntExistAnymore",
|
||||||
|
expectedRouteID: "route2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
currentRoute := &route.Route{
|
||||||
|
ID: "routeDoesntExistAnymore",
|
||||||
|
}
|
||||||
|
if tc.currentRoute != "" {
|
||||||
|
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||||
|
}
|
||||||
|
|
||||||
// create new clientNetwork
|
// create new clientNetwork
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
routes: tc.existingRoutes,
|
routes: tc.existingRoutes,
|
||||||
chosenRoute: tc.currentRoute,
|
chosenRoute: currentRoute,
|
||||||
}
|
}
|
||||||
|
|
||||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ package routemanager
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -14,6 +16,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,7 +27,7 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
|||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() error
|
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
@@ -65,16 +68,25 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Init sets up the routing
|
// Init sets up the routing
|
||||||
func (m *DefaultManager) Init() error {
|
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
if nbnet.CustomRoutingDisabled() {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
if err := cleanupRouting(); err != nil {
|
if err := cleanupRouting(); err != nil {
|
||||||
log.Warnf("Failed cleaning up routing: %v", err)
|
log.Warnf("Failed cleaning up routing: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := setupRouting(); err != nil {
|
mgmtAddress := m.statusRecorder.GetManagementState().URL
|
||||||
return fmt.Errorf("setup routing: %w", err)
|
signalAddress := m.statusRecorder.GetSignalState().URL
|
||||||
|
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
||||||
|
|
||||||
|
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||||
}
|
}
|
||||||
log.Info("Routing setup complete")
|
log.Info("Routing setup complete")
|
||||||
return nil
|
return beforePeerHook, afterPeerHook, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||||
@@ -92,11 +104,15 @@ func (m *DefaultManager) Stop() {
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.cleanUp()
|
m.serverRouter.cleanUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
if err := cleanupRouting(); err != nil {
|
if err := cleanupRouting(); err != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Routing cleanup complete")
|
log.Info("Routing cleanup complete")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.ctx = nil
|
m.ctx = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,16 +219,38 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||||
if runtime.GOOS == "linux" {
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux", "windows", "darwin":
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||||
// we skip this prefix management
|
// we skip this prefix management
|
||||||
if prefix.Bits() < minRangeBits {
|
if prefix.Bits() <= minRangeBits {
|
||||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||||
version.NetbirdVersion(), prefix)
|
version.NetbirdVersion(), prefix)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
|
||||||
|
func resolveURLsToIPs(urls []string) []net.IP {
|
||||||
|
var ips []net.IP
|
||||||
|
for _, rawurl := range urls {
|
||||||
|
u, err := url.Parse(rawurl)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse url %s: %v", rawurl, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ipAddrs, err := net.LookupIP(u.Hostname())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ips = append(ips, ipAddrs...)
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
removeSrvRouter bool
|
removeSrvRouter bool
|
||||||
serverRoutesExpected int
|
serverRoutesExpected int
|
||||||
clientNetworkWatchersExpected int
|
clientNetworkWatchersExpected int
|
||||||
clientNetworkWatchersExpectedLinux int
|
clientNetworkWatchersExpectedAllowed int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Should create 2 client networks",
|
name: "Should create 2 client networks",
|
||||||
@@ -203,7 +203,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
clientNetworkWatchersExpected: 0,
|
clientNetworkWatchersExpected: 0,
|
||||||
clientNetworkWatchersExpectedLinux: 1,
|
clientNetworkWatchersExpectedAllowed: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Remove 1 Client Route",
|
name: "Remove 1 Client Route",
|
||||||
@@ -417,7 +417,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||||
err = routeManager.Init()
|
|
||||||
|
_, _, err = routeManager.Init()
|
||||||
|
|
||||||
require.NoError(t, err, "should init route manager")
|
require.NoError(t, err, "should init route manager")
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop()
|
||||||
|
|
||||||
@@ -434,8 +436,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||||
if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 {
|
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||||
expectedWatchers = testCase.clientNetworkWatchersExpectedLinux
|
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
|
||||||
}
|
}
|
||||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -16,8 +17,8 @@ type MockManager struct {
|
|||||||
StopFunc func()
|
StopFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) Init() error {
|
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
return nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||||
|
|||||||
126
client/internal/routemanager/routemanager.go
Normal file
126
client/internal/routemanager/routemanager.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ref struct {
|
||||||
|
count int
|
||||||
|
nexthop netip.Addr
|
||||||
|
intf string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouteManager struct {
|
||||||
|
// refCountMap keeps track of the reference ref for prefixes
|
||||||
|
refCountMap map[netip.Prefix]ref
|
||||||
|
// prefixMap keeps track of the prefixes associated with a connection ID for removal
|
||||||
|
prefixMap map[nbnet.ConnectionID][]netip.Prefix
|
||||||
|
addRoute AddRouteFunc
|
||||||
|
removeRoute RemoveRouteFunc
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
|
||||||
|
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
|
||||||
|
|
||||||
|
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
||||||
|
// TODO: read initial routing table into refCountMap
|
||||||
|
return &RouteManager{
|
||||||
|
refCountMap: map[netip.Prefix]ref{},
|
||||||
|
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
|
||||||
|
addRoute: addRoute,
|
||||||
|
removeRoute: removeRoute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||||
|
rm.mutex.Lock()
|
||||||
|
defer rm.mutex.Unlock()
|
||||||
|
|
||||||
|
ref := rm.refCountMap[prefix]
|
||||||
|
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
|
||||||
|
|
||||||
|
// Add route to the system, only if it's a new prefix
|
||||||
|
if ref.count == 0 {
|
||||||
|
log.Debugf("Adding route for prefix %s", prefix)
|
||||||
|
nexthop, intf, err := rm.addRoute(prefix)
|
||||||
|
if errors.Is(err, ErrRouteNotFound) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrRouteNotAllowed) {
|
||||||
|
log.Debugf("Adding route for prefix %s: %s", prefix, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
ref.nexthop = nexthop
|
||||||
|
ref.intf = intf
|
||||||
|
}
|
||||||
|
|
||||||
|
ref.count++
|
||||||
|
rm.refCountMap[prefix] = ref
|
||||||
|
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
|
||||||
|
rm.mutex.Lock()
|
||||||
|
defer rm.mutex.Unlock()
|
||||||
|
|
||||||
|
prefixes, ok := rm.prefixMap[connID]
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("No prefixes found for connection ID %s", connID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
ref := rm.refCountMap[prefix]
|
||||||
|
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
|
||||||
|
if ref.count == 1 {
|
||||||
|
log.Debugf("Removing route for prefix %s", prefix)
|
||||||
|
// TODO: don't fail if the route is not found
|
||||||
|
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(rm.refCountMap, prefix)
|
||||||
|
} else {
|
||||||
|
ref.count--
|
||||||
|
rm.refCountMap[prefix] = ref
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(rm.prefixMap, connID)
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush removes all references and routes from the system
|
||||||
|
func (rm *RouteManager) Flush() error {
|
||||||
|
rm.mutex.Lock()
|
||||||
|
defer rm.mutex.Unlock()
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for prefix := range rm.refCountMap {
|
||||||
|
log.Debugf("Removing route for prefix %s", prefix)
|
||||||
|
ref := rm.refCountMap[prefix]
|
||||||
|
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rm.refCountMap = map[netip.Prefix]ref{}
|
||||||
|
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
@@ -155,11 +155,13 @@ func (m *defaultServerRouter) cleanUp() {
|
|||||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
state.Routes = nil
|
state.Routes = nil
|
||||||
m.statusRecorder.UpdateLocalPeerState(state)
|
m.statusRecorder.UpdateLocalPeerState(state)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
||||||
parsed, err := netip.ParsePrefix(source)
|
parsed, err := netip.ParsePrefix(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
428
client/internal/routemanager/systemops.go
Normal file
428
client/internal/routemanager/systemops.go
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/libp2p/go-netroute"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||||
|
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||||
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
|
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||||
|
|
||||||
|
var ErrRouteNotFound = errors.New("route not found")
|
||||||
|
var ErrRouteNotAllowed = errors.New("route not allowed")
|
||||||
|
|
||||||
|
// TODO: fix: for default our wg address now appears as the default gw
|
||||||
|
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||||
|
addr := netip.IPv4Unspecified()
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
addr = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGateway, _, err := getNextHop(addr)
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
return fmt.Errorf("get existing route gateway: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !prefix.Contains(defaultGateway) {
|
||||||
|
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
||||||
|
if defaultGateway.Is6() {
|
||||||
|
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := existsInRouteTable(gatewayPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitIntf string
|
||||||
|
gatewayHop, intf, err := getNextHop(defaultGateway)
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||||
|
}
|
||||||
|
if intf != nil {
|
||||||
|
exitIntf = intf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||||
|
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||||
|
r, err := netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||||
|
}
|
||||||
|
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get route for %s: %v", ip, err)
|
||||||
|
return netip.Addr{}, nil, ErrRouteNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||||
|
if gateway == nil {
|
||||||
|
if preferredSrc == nil {
|
||||||
|
return netip.Addr{}, nil, ErrRouteNotFound
|
||||||
|
}
|
||||||
|
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
||||||
|
|
||||||
|
addr, err := ipToAddr(preferredSrc, intf)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
|
||||||
|
}
|
||||||
|
return addr.Unmap(), intf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := ipToAddr(gateway, intf)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr, intf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||||
|
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||||
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
||||||
|
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
addr = addr.WithZone(strconv.Itoa(intf.Index))
|
||||||
|
} else {
|
||||||
|
addr = addr.WithZone(intf.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr.Unmap(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get routes from table: %w", err)
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute == prefix {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get routes from table: %w", err)
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||||
|
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||||
|
func addRouteToNonVPNIntf(
|
||||||
|
prefix netip.Prefix,
|
||||||
|
vpnIntf *iface.WGIface,
|
||||||
|
initialNextHop netip.Addr,
|
||||||
|
initialIntf *net.Interface,
|
||||||
|
) (netip.Addr, string, error) {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
switch {
|
||||||
|
case addr.IsLoopback(),
|
||||||
|
addr.IsLinkLocalUnicast(),
|
||||||
|
addr.IsLinkLocalMulticast(),
|
||||||
|
addr.IsInterfaceLocalMulticast(),
|
||||||
|
addr.IsUnspecified(),
|
||||||
|
addr.IsMulticast():
|
||||||
|
|
||||||
|
return netip.Addr{}, "", ErrRouteNotAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||||
|
nexthop, intf, err := getNextHop(addr)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||||
|
exitNextHop := nexthop
|
||||||
|
var exitIntf string
|
||||||
|
if intf != nil {
|
||||||
|
exitIntf = intf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||||
|
}
|
||||||
|
|
||||||
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||||
|
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
|
||||||
|
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||||
|
exitNextHop = initialNextHop
|
||||||
|
if initialIntf != nil {
|
||||||
|
exitIntf = initialIntf.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||||
|
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return exitNextHop, exitIntf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||||
|
// in two /1 prefixes to avoid replacing the existing default route
|
||||||
|
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if prefix == defaultv4 {
|
||||||
|
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove once IPv6 is supported on the interface
|
||||||
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
} else if prefix == defaultv6 {
|
||||||
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return addNonExistingRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||||
|
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
ok, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("exists in route table: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = isSubRange(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sub range: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
err := addRouteForCurrentDefaultGateway(prefix)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addToRouteTable(prefix, netip.Addr{}, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||||
|
// it will remove the split /1 prefixes
|
||||||
|
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if prefix == defaultv4 {
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove once IPv6 is supported on the interface
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
} else if prefix == defaultv6 {
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
||||||
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("parse IP address: %s", ip)
|
||||||
|
}
|
||||||
|
addr = addr.Unmap()
|
||||||
|
|
||||||
|
var prefixLength int
|
||||||
|
switch {
|
||||||
|
case addr.Is4():
|
||||||
|
prefixLength = 32
|
||||||
|
case addr.Is6():
|
||||||
|
prefixLength = 128
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||||
|
return &prefix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||||
|
}
|
||||||
|
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
*routeManager = NewRouteManager(
|
||||||
|
func(prefix netip.Prefix) (netip.Addr, string, error) {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||||
|
if addr.Is6() {
|
||||||
|
nexthop, intf = initialNextHopV6, initialIntfV6
|
||||||
|
}
|
||||||
|
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
||||||
|
},
|
||||||
|
removeFromRouteTable,
|
||||||
|
)
|
||||||
|
|
||||||
|
return setupHooks(*routeManager, initAddresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
||||||
|
if routeManager == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Remove hooks selectively
|
||||||
|
nbnet.RemoveDialerHooks()
|
||||||
|
nbnet.RemoveListenerHooks()
|
||||||
|
|
||||||
|
if err := routeManager.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush route manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||||
|
prefix, err := getPrefixFromIP(ip)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
||||||
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
afterHook := func(connID nbnet.ConnectionID) error {
|
||||||
|
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
||||||
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range initAddresses {
|
||||||
|
if err := beforeHook("init", ip); err != nil {
|
||||||
|
log.Errorf("Failed to add route reference: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for _, ip := range resolvedIPs {
|
||||||
|
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||||
|
}
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||||
|
return afterHook(connID)
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||||
|
return beforeHook(connID, ip.IP)
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||||
|
return afterHook(connID)
|
||||||
|
})
|
||||||
|
|
||||||
|
return beforeHook, afterHook, nil
|
||||||
|
}
|
||||||
@@ -1,13 +1,33 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
|
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
|
func enableIPForwarding() error {
|
||||||
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(netip.Prefix, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeVPNRoute(netip.Prefix, string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(m.Addrs) < 3 {
|
||||||
|
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
addr, ok := toNetIPAddr(m.Addrs[0])
|
addr, ok := toNetIPAddr(m.Addrs[0])
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
mask, ok := toNetIPMASK(m.Addrs[2])
|
cidr := 32
|
||||||
|
if mask := m.Addrs[2]; mask != nil {
|
||||||
|
cidr, ok = toCIDR(mask)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
cidr, _ := mask.Size()
|
}
|
||||||
|
|
||||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
routePrefix := netip.PrefixFrom(addr, cidr)
|
||||||
if routePrefix.IsValid() {
|
if routePrefix.IsValid() {
|
||||||
@@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
|
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
return netip.AddrFrom4(t.IP), true
|
||||||
addr := netip.MustParseAddr(ip.String())
|
|
||||||
return addr, true
|
|
||||||
default:
|
default:
|
||||||
return netip.Addr{}, false
|
return netip.Addr{}, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toNetIPMASK(a route.Addr) (net.IPMask, bool) {
|
func toCIDR(a route.Addr) (int, bool) {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
||||||
return mask, true
|
cidr, _ := mask.Size()
|
||||||
|
return cidr, true
|
||||||
default:
|
default:
|
||||||
return nil, false
|
return 0, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import "net/netip"
|
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
|
||||||
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
|
||||||
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
|
|
||||||
}
|
|
||||||
89
client/internal/routemanager/systemops_darwin.go
Normal file
89
client/internal/routemanager/systemops_darwin.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
var routeManager *RouteManager
|
||||||
|
|
||||||
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
|
return cleanupRoutingWithRouteManager(routeManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return routeCmd("add", prefix, nexthop, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return routeCmd("delete", prefix, nexthop, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
inet := "-inet"
|
||||||
|
network := prefix.String()
|
||||||
|
if prefix.IsSingleIP() {
|
||||||
|
network = prefix.Addr().String()
|
||||||
|
}
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
inet = "-inet6"
|
||||||
|
// Special case for IPv6 split default route, pointing to the wg interface fails
|
||||||
|
// TODO: Remove once we have IPv6 support on the interface
|
||||||
|
if prefix.Bits() == 1 {
|
||||||
|
intf = "lo0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{"-n", action, inet, network}
|
||||||
|
if nexthop.IsValid() {
|
||||||
|
args = append(args, nexthop.Unmap().String())
|
||||||
|
} else if intf != "" {
|
||||||
|
args = append(args, "-interface", intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := retryRouteCmd(args); err != nil {
|
||||||
|
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func retryRouteCmd(args []string) error {
|
||||||
|
operation := func() error {
|
||||||
|
out, err := exec.Command("route", args...).CombinedOutput()
|
||||||
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
|
// https://github.com/golang/go/issues/45736
|
||||||
|
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
|
||||||
|
return err
|
||||||
|
} else if err != nil {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expBackOff := backoff.NewExponentialBackOff()
|
||||||
|
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||||
|
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||||
|
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||||
|
|
||||||
|
err := backoff.Retry(operation, expBackOff)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route cmd retry failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
138
client/internal/routemanager/systemops_darwin_test.go
Normal file
138
client/internal/routemanager/systemops_darwin_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var expectedVPNint = "utun100"
|
||||||
|
var expectedExternalInt = "lo0"
|
||||||
|
var expectedInternalInt = "lo0"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
testCases = append(testCases, []testCase{
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via vpn",
|
||||||
|
destination: "10.10.0.2:53",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentRoutes(t *testing.T) {
|
||||||
|
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||||
|
intf := "lo0"
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||||
|
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||||
|
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||||
|
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||||
|
require.NoError(t, err, "Failed to create loopback alias")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||||
|
})
|
||||||
|
|
||||||
|
return "lo0"
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var originalNexthop net.IP
|
||||||
|
if dstCIDR == "0.0.0.0/0" {
|
||||||
|
var err error
|
||||||
|
originalNexthop, err = fetchOriginalGateway()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to fetch original gateway: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||||
|
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if originalNexthop != nil {
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||||
|
assert.NoError(t, err, "Failed to restore original route")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||||
|
require.NoError(t, err, "Failed to add route")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove route")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOriginalGateway() (net.IP, error) {
|
||||||
|
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil, fmt.Errorf("gateway not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.ParseIP(matches[1]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||||
|
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||||
|
|
||||||
|
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||||
|
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||||
|
}
|
||||||
@@ -1,13 +1,33 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
|
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
|
func enableIPForwarding() error {
|
||||||
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(netip.Prefix, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeVPNRoute(netip.Prefix, string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,16 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,16 +32,31 @@ const (
|
|||||||
rtTablesPath = "/etc/iproute2/rt_tables"
|
rtTablesPath = "/etc/iproute2/rt_tables"
|
||||||
|
|
||||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||||
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||||
|
|
||||||
|
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||||
|
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||||
|
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||||
|
|
||||||
|
var routeManager = &RouteManager{}
|
||||||
|
|
||||||
|
// originalSysctl stores the original sysctl values before they are modified
|
||||||
|
var originalSysctl map[string]int
|
||||||
|
|
||||||
|
// determines whether to use the legacy routing setup
|
||||||
|
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||||
|
|
||||||
|
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
|
||||||
|
var sysctlFailed bool
|
||||||
|
|
||||||
type ruleParams struct {
|
type ruleParams struct {
|
||||||
|
priority int
|
||||||
fwmark int
|
fwmark int
|
||||||
tableID int
|
tableID int
|
||||||
family int
|
family int
|
||||||
priority int
|
|
||||||
invert bool
|
invert bool
|
||||||
suppressPrefix int
|
suppressPrefix int
|
||||||
description string
|
description string
|
||||||
@@ -45,10 +64,10 @@ type ruleParams struct {
|
|||||||
|
|
||||||
func getSetupRules() []ruleParams {
|
func getSetupRules() []ruleParams {
|
||||||
return []ruleParams{
|
return []ruleParams{
|
||||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"},
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"},
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
|
||||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"},
|
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
|
||||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"},
|
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,13 +81,23 @@ func getSetupRules() []ruleParams {
|
|||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
//
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||||
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
if isLegacy {
|
||||||
func setupRouting() (err error) {
|
log.Infof("Using legacy routing setup")
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
|
||||||
if err = addRoutingTableName(); err != nil {
|
if err = addRoutingTableName(); err != nil {
|
||||||
log.Errorf("Error adding routing table name: %v", err)
|
log.Errorf("Error adding routing table name: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
originalValues, err := setupSysctl(wgIface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error setting up sysctl: %v", err)
|
||||||
|
sysctlFailed = true
|
||||||
|
}
|
||||||
|
originalSysctl = originalValues
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||||
@@ -80,17 +109,26 @@ func setupRouting() (err error) {
|
|||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := addRule(rule); err != nil {
|
if err := addRule(rule); err != nil {
|
||||||
return fmt.Errorf("%s: %w", rule.description, err)
|
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
|
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||||
|
isLegacy = true
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
func cleanupRouting() error {
|
func cleanupRouting() error {
|
||||||
|
if isLegacy {
|
||||||
|
return cleanupRoutingWithRouteManager(routeManager)
|
||||||
|
}
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||||
@@ -102,165 +140,79 @@ func cleanupRouting() error {
|
|||||||
|
|
||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := removeAllRules(rule); err != nil {
|
if err := removeRule(rule); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||||
|
}
|
||||||
|
originalSysctl = nil
|
||||||
|
sysctlFailed = false
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error {
|
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 2
|
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if isLegacy {
|
||||||
|
return genericAddVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||||
|
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
// TODO remove this once we have ipv6 support
|
||||||
if prefix == defaultv4 {
|
if prefix == defaultv4 {
|
||||||
if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("add blackhole: %w", err)
|
return fmt.Errorf("add blackhole: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("add route: %w", err)
|
return fmt.Errorf("add route: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error {
|
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if isLegacy {
|
||||||
|
return genericRemoveVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
// TODO remove this once we have ipv6 support
|
||||||
if prefix == defaultv4 {
|
if prefix == defaultv4 {
|
||||||
if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("remove unreachable route: %w", err)
|
return fmt.Errorf("remove unreachable route: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("remove route: %w", err)
|
return fmt.Errorf("remove route: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4)
|
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
|
||||||
}
|
|
||||||
|
|
||||||
// addRoute adds a route to a specific routing table identified by tableID.
|
|
||||||
func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
|
||||||
route := &netlink.Route{
|
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: family,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix != nil {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
return nil, fmt.Errorf("get v4 routes: %w", err)
|
||||||
}
|
}
|
||||||
route.Dst = ipNet
|
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
|
||||||
}
|
|
||||||
|
|
||||||
if err := addNextHop(addr, intf, route); err != nil {
|
|
||||||
return fmt.Errorf("add gateway and device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) {
|
|
||||||
return fmt.Errorf("netlink add route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
|
||||||
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
|
||||||
// tableID specifies the routing table to which the unreachable route will be added.
|
|
||||||
func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
return nil, fmt.Errorf("get v6 routes: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Type: syscall.RTN_UNREACHABLE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: ipFamily,
|
|
||||||
Dst: ipNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) {
|
|
||||||
return fmt.Errorf("netlink add unreachable route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Type: syscall.RTN_UNREACHABLE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: ipFamily,
|
|
||||||
Dst: ipNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
|
|
||||||
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
return append(v4Routes, v6Routes...), nil
|
||||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
|
||||||
func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: family,
|
|
||||||
Dst: ipNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := addNextHop(addr, intf, route); err != nil {
|
|
||||||
return fmt.Errorf("add gateway and device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
|
|
||||||
return fmt.Errorf("netlink remove route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func flushRoutes(tableID, family int) error {
|
|
||||||
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list routes from table %d: %w", tableID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *multierror.Error
|
|
||||||
for i := range routes {
|
|
||||||
route := routes[i]
|
|
||||||
// unreachable default routes don't come back with Dst set
|
|
||||||
if route.Gw == nil && route.Src == nil && route.Dst == nil {
|
|
||||||
if family == netlink.FAMILY_V4 {
|
|
||||||
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
|
|
||||||
} else {
|
|
||||||
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := netlink.RouteDel(&routes[i]); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRoutes fetches routes from a specific routing table identified by tableID.
|
// getRoutes fetches routes from a specific routing table identified by tableID.
|
||||||
@@ -291,25 +243,130 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
|||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
// addRoute adds a route to a specific routing table identified by tableID.
|
||||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||||
|
route := &netlink.Route{
|
||||||
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
route.Dst = ipNet
|
||||||
|
|
||||||
|
if err := addNextHop(addr, intf, route); err != nil {
|
||||||
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink add route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if it is already enabled
|
|
||||||
// see more: https://github.com/netbirdio/netbird/issues/872
|
|
||||||
if len(bytes) > 0 && bytes[0] == 49 {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gosec
|
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
||||||
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
|
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
||||||
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
|
// tableID specifies the routing table to which the unreachable route will be added.
|
||||||
|
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
route := &netlink.Route{
|
||||||
|
Type: syscall.RTN_UNREACHABLE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink add unreachable route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
route := &netlink.Route{
|
||||||
|
Type: syscall.RTN_UNREACHABLE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||||
|
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
route := &netlink.Route{
|
||||||
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := addNextHop(addr, intf, route); err != nil {
|
||||||
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink remove route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flushRoutes(tableID, family int) error {
|
||||||
|
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list routes from table %d: %w", tableID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for i := range routes {
|
||||||
|
route := routes[i]
|
||||||
|
// unreachable default routes don't come back with Dst set
|
||||||
|
if route.Gw == nil && route.Src == nil && route.Dst == nil {
|
||||||
|
if family == netlink.FAMILY_V4 {
|
||||||
|
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
|
||||||
|
} else {
|
||||||
|
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func enableIPForwarding() error {
|
||||||
|
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||||
// and verifies if existing names start with "netbird_".
|
// and verifies if existing names start with "netbird_".
|
||||||
func entryExists(file *os.File, id int) (bool, error) {
|
func entryExists(file *os.File, id int) (bool, error) {
|
||||||
@@ -385,7 +442,7 @@ func addRule(params ruleParams) error {
|
|||||||
rule.Invert = params.invert
|
rule.Invert = params.invert
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleAdd(rule); err != nil {
|
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return fmt.Errorf("add routing rule: %w", err)
|
return fmt.Errorf("add routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,43 +459,116 @@ func removeRule(params ruleParams) error {
|
|||||||
rule.Priority = params.priority
|
rule.Priority = params.priority
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleDel(rule); err != nil {
|
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
return fmt.Errorf("remove routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeAllRules(params ruleParams) error {
|
|
||||||
for {
|
|
||||||
if err := removeRule(params); err != nil {
|
|
||||||
if errors.Is(err, syscall.ENOENT) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNextHop adds the gateway and device to the route.
|
// addNextHop adds the gateway and device to the route.
|
||||||
func addNextHop(addr *string, intf *string, route *netlink.Route) error {
|
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
||||||
if addr != nil {
|
if addr.IsValid() {
|
||||||
ip := net.ParseIP(*addr)
|
route.Gw = addr.AsSlice()
|
||||||
if ip == nil {
|
if intf == "" {
|
||||||
return fmt.Errorf("parsing address %s failed", *addr)
|
intf = addr.Zone()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
route.Gw = ip
|
if intf != "" {
|
||||||
}
|
link, err := netlink.LinkByName(intf)
|
||||||
|
|
||||||
if intf != nil {
|
|
||||||
link, err := netlink.LinkByName(*intf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set interface %s: %w", *intf, err)
|
return fmt.Errorf("set interface %s: %w", intf, err)
|
||||||
}
|
}
|
||||||
route.LinkIndex = link.Attrs().Index
|
route.LinkIndex = link.Attrs().Index
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getAddressFamily(prefix netip.Prefix) int {
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
return netlink.FAMILY_V4
|
||||||
|
}
|
||||||
|
return netlink.FAMILY_V6
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||||
|
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||||
|
keys := map[string]int{}
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
|
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[srcValidMarkPath] = oldVal
|
||||||
|
}
|
||||||
|
|
||||||
|
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[rpFilterPath] = oldVal
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||||
|
oldVal, err := setSysctl(i, 2, true)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[i] = oldVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||||
|
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||||
|
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||||
|
currentValue, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||||
|
if err != nil && len(currentValue) > 0 {
|
||||||
|
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||||
|
return currentV, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gosec
|
||||||
|
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||||
|
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||||
|
|
||||||
|
return currentV, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupSysctl(originalSettings map[string]int) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
|
for key, value := range originalSettings {
|
||||||
|
_, err := setSysctl(key, value, false)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,34 +6,38 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gopacket/gopacket"
|
|
||||||
"github.com/gopacket/gopacket/layers"
|
|
||||||
"github.com/gopacket/gopacket/pcap"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketExpectation struct {
|
var expectedVPNint = "wgtest0"
|
||||||
SrcIP net.IP
|
var expectedLoopbackInt = "lo"
|
||||||
DstIP net.IP
|
var expectedExternalInt = "dummyext0"
|
||||||
SrcPort int
|
var expectedInternalInt = "dummyint0"
|
||||||
DstPort int
|
|
||||||
UDP bool
|
func init() {
|
||||||
TCP bool
|
testCases = append(testCases, []testCase{
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via physical interface",
|
||||||
|
destination: "10.10.0.2:53",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To more specific route (local) without custom dialer via physical interface",
|
||||||
|
destination: "127.0.10.1:53",
|
||||||
|
expectedInterface: expectedLoopbackInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
||||||
|
},
|
||||||
|
}...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEntryExists(t *testing.T) {
|
func TestEntryExists(t *testing.T) {
|
||||||
@@ -92,157 +96,7 @@ func TestEntryExists(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutingWithTables(t *testing.T) {
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
destination string
|
|
||||||
captureInterface string
|
|
||||||
dialer *net.Dialer
|
|
||||||
packetExpectation PacketExpectation
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "To external host without fwmark via vpn",
|
|
||||||
destination: "192.0.2.1:53",
|
|
||||||
captureInterface: "wgtest0",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To external host with fwmark via physical interface",
|
|
||||||
destination: "192.0.2.1:53",
|
|
||||||
captureInterface: "dummyext0",
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route with fwmark via physical interface",
|
|
||||||
destination: "10.0.0.1:53",
|
|
||||||
captureInterface: "dummyint0",
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence
|
|
||||||
destination: "10.0.0.1:53",
|
|
||||||
captureInterface: "dummyint0",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To unique vpn route with fwmark via physical interface",
|
|
||||||
destination: "172.16.0.1:53",
|
|
||||||
captureInterface: "dummyext0",
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To unique vpn route without fwmark via vpn",
|
|
||||||
destination: "172.16.0.1:53",
|
|
||||||
captureInterface: "wgtest0",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To more specific route without fwmark via vpn interface",
|
|
||||||
destination: "10.10.0.1:53",
|
|
||||||
captureInterface: "dummyint0",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To more specific route (local) without fwmark via physical interface",
|
|
||||||
destination: "127.0.10.1:53",
|
|
||||||
captureInterface: "lo",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
wgIface, _, _ := setupTestEnv(t)
|
|
||||||
|
|
||||||
// default route exists in main table and vpn table
|
|
||||||
err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
||||||
|
|
||||||
// 10.0.0.0/8 route exists in main table and vpn table
|
|
||||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
||||||
|
|
||||||
// 10.10.0.0/24 more specific route exists in vpn table
|
|
||||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
||||||
|
|
||||||
// 127.0.10.0/24 more specific route exists in vpn table
|
|
||||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
||||||
|
|
||||||
// unique route in vpn table
|
|
||||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
||||||
|
|
||||||
filter := createBPFFilter(tc.destination)
|
|
||||||
handle := startPacketCapture(t, tc.captureInterface, filter)
|
|
||||||
|
|
||||||
sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer)
|
|
||||||
|
|
||||||
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
|
||||||
packet, err := packetSource.NextPacket()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
verifyPacket(t, packet, tc.packetExpectation)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
|
||||||
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
|
||||||
|
|
||||||
ip, ok := ipLayer.(*layers.IPv4)
|
|
||||||
require.True(t, ok, "Failed to cast to IPv4 layer")
|
|
||||||
|
|
||||||
// Convert both source and destination IP addresses to 16-byte representation
|
|
||||||
expectedSrcIP := exp.SrcIP.To16()
|
|
||||||
actualSrcIP := ip.SrcIP.To16()
|
|
||||||
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
|
||||||
|
|
||||||
expectedDstIP := exp.DstIP.To16()
|
|
||||||
actualDstIP := ip.DstIP.To16()
|
|
||||||
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
|
||||||
|
|
||||||
if exp.UDP {
|
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
||||||
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
|
||||||
|
|
||||||
udp, ok := udpLayer.(*layers.UDP)
|
|
||||||
require.True(t, ok, "Failed to cast to UDP layer")
|
|
||||||
|
|
||||||
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
|
||||||
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
if exp.TCP {
|
|
||||||
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
|
||||||
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
|
||||||
|
|
||||||
tcp, ok := tcpLayer.(*layers.TCP)
|
|
||||||
require.True(t, ok, "Failed to cast to TCP layer")
|
|
||||||
|
|
||||||
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
|
||||||
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
|
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
|
||||||
@@ -264,34 +118,51 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return dummy
|
t.Cleanup(func() {
|
||||||
|
err := netlink.LinkDel(dummy)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
return dummy.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
|
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
|
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Handle existing routes with metric 0
|
||||||
|
var originalNexthop net.IP
|
||||||
|
var originalLinkIndex int
|
||||||
if dstIPNet.String() == "0.0.0.0/0" {
|
if dstIPNet.String() == "0.0.0.0/0" {
|
||||||
gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4)
|
var err error
|
||||||
if err != nil {
|
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
t.Logf("Failed to fetch original gateway: %v", err)
|
t.Logf("Failed to fetch original gateway: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle existing routes with metric 0
|
if originalNexthop != nil {
|
||||||
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
||||||
if err == nil {
|
switch {
|
||||||
|
case err != nil && !errors.Is(err, syscall.ESRCH):
|
||||||
|
t.Logf("Failed to delete route: %v", err)
|
||||||
|
case err == nil:
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0})
|
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0})
|
||||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||||
t.Fatalf("Failed to add route: %v", err)
|
t.Fatalf("Failed to add route: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
} else if !errors.Is(err, syscall.ESRCH) {
|
default:
|
||||||
t.Logf("Failed to delete route: %v", err)
|
t.Logf("Failed to delete route: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := netlink.LinkByName(intf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
linkIndex := link.Attrs().Index
|
||||||
|
|
||||||
route := &netlink.Route{
|
route := &netlink.Route{
|
||||||
Dst: dstIPNet,
|
Dst: dstIPNet,
|
||||||
@@ -307,9 +178,9 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
|
|||||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||||
t.Fatalf("Failed to add route: %v", err)
|
t.Fatalf("Failed to add route: %v", err)
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchOriginalGateway returns the original gateway IP address and the interface index.
|
|
||||||
func fetchOriginalGateway(family int) (net.IP, int, error) {
|
func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||||
routes, err := netlink.RouteList(nil, family)
|
routes, err := netlink.RouteList(nil, family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -317,153 +188,20 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Dst == nil {
|
if route.Dst == nil && route.Priority == 0 {
|
||||||
return route.Gw, route.LinkIndex, nil
|
return route.Gw, route.LinkIndex, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, 0, fmt.Errorf("default route not found")
|
return nil, 0, ErrRouteNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) {
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
|
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
|
||||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index)
|
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||||
|
|
||||||
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
|
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
|
||||||
addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index)
|
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := netlink.LinkDel(defaultDummy)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = netlink.LinkDel(otherDummy)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
return defaultDummy.Name, otherDummy.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
|
||||||
require.NoError(t, err, "should create testing WireGuard interface")
|
|
||||||
|
|
||||||
err = wgInterface.Create()
|
|
||||||
require.NoError(t, err, "should create testing WireGuard interface")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
wgInterface.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
return wgInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t)
|
|
||||||
|
|
||||||
wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, wgIface.Close())
|
|
||||||
})
|
|
||||||
|
|
||||||
err := setupRouting()
|
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
return wgIface, defaultDummy, otherDummy
|
|
||||||
}
|
|
||||||
|
|
||||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
inactive, err := pcap.NewInactiveHandle(intf)
|
|
||||||
require.NoError(t, err, "Failed to create inactive pcap handle")
|
|
||||||
defer inactive.CleanUp()
|
|
||||||
|
|
||||||
err = inactive.SetSnapLen(1600)
|
|
||||||
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
|
||||||
|
|
||||||
err = inactive.SetTimeout(time.Second * 10)
|
|
||||||
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
|
||||||
|
|
||||||
err = inactive.SetImmediateMode(true)
|
|
||||||
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
|
||||||
|
|
||||||
handle, err := inactive.Activate()
|
|
||||||
require.NoError(t, err, "Failed to activate pcap handle")
|
|
||||||
t.Cleanup(handle.Close)
|
|
||||||
|
|
||||||
err = handle.SetBPFFilter(filter)
|
|
||||||
require.NoError(t, err, "Failed to set BPF filter")
|
|
||||||
|
|
||||||
return handle
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if dialer == nil {
|
|
||||||
dialer = &net.Dialer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sourcePort != 0 {
|
|
||||||
localUDPAddr := &net.UDPAddr{
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
Port: sourcePort,
|
|
||||||
}
|
|
||||||
dialer.LocalAddr = localUDPAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.Id = dns.Id()
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
msg.Question = []dns.Question{
|
|
||||||
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := dialer.Dial("udp", destination)
|
|
||||||
require.NoError(t, err, "Failed to dial UDP")
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
data, err := msg.Pack()
|
|
||||||
require.NoError(t, err, "Failed to pack DNS message")
|
|
||||||
|
|
||||||
_, err = conn.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "required key not available") {
|
|
||||||
t.Logf("Ignoring WireGuard key error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("Failed to send DNS query: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createBPFFilter(destination string) string {
|
|
||||||
host, port, err := net.SplitHostPort(destination)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
|
|
||||||
}
|
|
||||||
return "udp"
|
|
||||||
}
|
|
||||||
|
|
||||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
|
||||||
return PacketExpectation{
|
|
||||||
SrcIP: net.ParseIP(srcIP),
|
|
||||||
DstIP: net.ParseIP(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
UDP: true,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,148 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
//nolint:unused
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/libp2p/go-netroute"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
var errRouteNotFound = fmt.Errorf("route not found")
|
|
||||||
|
|
||||||
func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|
||||||
defaultGateway, err := getExistingRIBRouteGateway(defaultv4)
|
|
||||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
|
||||||
return fmt.Errorf("get existing route gateway: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := netip.MustParseAddr(defaultGateway.String())
|
|
||||||
|
|
||||||
if !prefix.Contains(addr) {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayPrefix := netip.PrefixFrom(addr, 32)
|
|
||||||
|
|
||||||
ok, err := existsInRouteTable(gatewayPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
|
|
||||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
|
||||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
|
||||||
}
|
|
||||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
|
||||||
return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
|
||||||
ok, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("exists in route table: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sub range: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
err := genericAddRouteForCurrentDefaultGateway(prefix)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return genericAddToRouteTable(prefix, addr, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
|
||||||
return genericRemoveFromRouteTable(prefix, addr, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error {
|
|
||||||
cmd := exec.Command("route", "add", prefix.String(), addr)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("add route: %w", err)
|
|
||||||
}
|
|
||||||
log.Debugf(string(out))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error {
|
|
||||||
args := []string{"delete", prefix.String()}
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
args = append(args, addr)
|
|
||||||
}
|
|
||||||
cmd := exec.Command("route", args...)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove route: %w", err)
|
|
||||||
}
|
|
||||||
log.Debugf(string(out))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
|
||||||
r, err := netroute.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("new netroute: %w", err)
|
|
||||||
}
|
|
||||||
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Getting routes returned an error: %v", err)
|
|
||||||
return nil, errRouteNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
if gateway == nil {
|
|
||||||
return preferredSrc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return gateway, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute == prefix {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
@@ -1,22 +1,23 @@
|
|||||||
//go:build !linux || android
|
//go:build !linux && !ios
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupRouting() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupRouting() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func enableIPForwarding() error {
|
||||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
return genericAddVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
return genericRemoveVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIsSubRange(t *testing.T) {
|
|
||||||
addresses, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var subRangeAddressPrefixes []netip.Prefix
|
|
||||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
|
||||||
for _, address := range addresses {
|
|
||||||
p := netip.MustParsePrefix(address.String())
|
|
||||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
|
||||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
|
||||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
|
||||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range subRangeAddressPrefixes {
|
|
||||||
isSubRangePrefix, err := isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
|
||||||
}
|
|
||||||
if !isSubRangePrefix {
|
|
||||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
|
||||||
isSubRangePrefix, err := isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
|
||||||
}
|
|
||||||
if isSubRangePrefix {
|
|
||||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExistsInRouteTable(t *testing.T) {
|
|
||||||
require.NoError(t, setupRouting())
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
addresses, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addressPrefixes []netip.Prefix
|
|
||||||
for _, address := range addresses {
|
|
||||||
p := netip.MustParsePrefix(address.String())
|
|
||||||
if p.Addr().Is4() {
|
|
||||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range addressPrefixes {
|
|
||||||
exists, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
t.Fatalf("address %s should exist in route table", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
//go:build !android
|
//go:build !android && !ios
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -22,47 +22,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
type dialer interface {
|
||||||
t.Helper()
|
Dial(network, address string) (net.Conn, error)
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String())
|
|
||||||
require.NoError(t, err, "getOutgoingInterfaceLinux should not return error")
|
|
||||||
if invert {
|
|
||||||
require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface")
|
|
||||||
} else {
|
|
||||||
require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixGateway, err := getExistingRIBRouteGateway(prefix)
|
|
||||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
|
||||||
if invert {
|
|
||||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOutgoingInterfaceLinux(destination string) (string, error) {
|
|
||||||
cmd := exec.Command("ip", "route", "get", destination)
|
|
||||||
output, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("executing ip route get: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return parseOutgoingInterface(string(output)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseOutgoingInterface(routeGetOutput string) string {
|
|
||||||
fields := strings.Fields(routeGetOutput)
|
|
||||||
for i, field := range fields {
|
|
||||||
if field == "dev" && i+1 < len(fields) {
|
|
||||||
return fields[i+1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddRemoveRoutes(t *testing.T) {
|
func TestAddRemoveRoutes(t *testing.T) {
|
||||||
@@ -99,14 +61,14 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
|
|
||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
_, _, err = setupRouting(nil, wgInterface)
|
||||||
require.NoError(t, setupRouting())
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, cleanupRouting())
|
assert.NoError(t, cleanupRouting())
|
||||||
})
|
})
|
||||||
|
|
||||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
|
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||||
|
|
||||||
if testCase.shouldRouteToWireguard {
|
if testCase.shouldRouteToWireguard {
|
||||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||||
@@ -116,13 +78,13 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
exists, err := existsInRouteTable(testCase.prefix)
|
exists, err := existsInRouteTable(testCase.prefix)
|
||||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||||
if exists && testCase.shouldRouteToWireguard {
|
if exists && testCase.shouldRouteToWireguard {
|
||||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
|
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
|
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||||
|
|
||||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
||||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
require.NoError(t, err, "getNextHop should not return err")
|
||||||
|
|
||||||
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if testCase.shouldBeRemoved {
|
if testCase.shouldBeRemoved {
|
||||||
@@ -135,12 +97,12 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetExistingRIBRouteGateway(t *testing.T) {
|
func TestGetNextHop(t *testing.T) {
|
||||||
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
}
|
}
|
||||||
if gateway == nil {
|
if !gateway.IsValid() {
|
||||||
t.Fatal("should return a gateway")
|
t.Fatal("should return a gateway")
|
||||||
}
|
}
|
||||||
addresses, err := net.InterfaceAddrs()
|
addresses, err := net.InterfaceAddrs()
|
||||||
@@ -162,11 +124,11 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localIP, err := getExistingRIBRouteGateway(testingPrefix)
|
localIP, _, err := getNextHop(testingPrefix.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error: ", err)
|
t.Fatal("shouldn't return error: ", err)
|
||||||
}
|
}
|
||||||
if localIP == nil {
|
if !localIP.IsValid() {
|
||||||
t.Fatal("should return a gateway for local network")
|
t.Fatal("should return a gateway for local network")
|
||||||
}
|
}
|
||||||
if localIP.String() == gateway.String() {
|
if localIP.String() == gateway.String() {
|
||||||
@@ -177,8 +139,8 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||||
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
t.Log("defaultGateway: ", defaultGateway)
|
t.Log("defaultGateway: ", defaultGateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
@@ -238,21 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
require.NoError(t, setupRouting())
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
MockAddr := wgInterface.Address().IP.String()
|
|
||||||
|
|
||||||
// Prepare the environment
|
// Prepare the environment
|
||||||
if testCase.preExistingPrefix.IsValid() {
|
if testCase.preExistingPrefix.IsValid() {
|
||||||
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name())
|
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the route
|
// Add the route
|
||||||
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name())
|
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err when adding route")
|
require.NoError(t, err, "should not return err when adding route")
|
||||||
|
|
||||||
if testCase.shouldAddRoute {
|
if testCase.shouldAddRoute {
|
||||||
@@ -262,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
require.True(t, ok, "route should exist")
|
require.True(t, ok, "route should exist")
|
||||||
|
|
||||||
// remove route again if added
|
// remove route again if added
|
||||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name())
|
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,11 +227,176 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
t.Log("Buffer string: ", buf.String())
|
t.Log("Buffer string: ", buf.String())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
|
|
||||||
// Linux uses a separate routing table, so the route can exist in both tables.
|
if !strings.Contains(buf.String(), "because it already exists") {
|
||||||
// The main routing table takes precedence over the wireguard routing table.
|
|
||||||
if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" {
|
|
||||||
require.False(t, ok, "route should not exist")
|
require.False(t, ok, "route should not exist")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsSubRange(t *testing.T) {
|
||||||
|
addresses, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var subRangeAddressPrefixes []netip.Prefix
|
||||||
|
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||||
|
for _, address := range addresses {
|
||||||
|
p := netip.MustParsePrefix(address.String())
|
||||||
|
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||||
|
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||||
|
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||||
|
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range subRangeAddressPrefixes {
|
||||||
|
isSubRangePrefix, err := isSubRange(prefix)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||||
|
}
|
||||||
|
if !isSubRangePrefix {
|
||||||
|
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range nonSubRangeAddressPrefixes {
|
||||||
|
isSubRangePrefix, err := isSubRange(prefix)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||||
|
}
|
||||||
|
if isSubRangePrefix {
|
||||||
|
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExistsInRouteTable(t *testing.T) {
|
||||||
|
addresses, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var addressPrefixes []netip.Prefix
|
||||||
|
for _, address := range addresses {
|
||||||
|
p := netip.MustParsePrefix(address.String())
|
||||||
|
if p.Addr().Is6() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||||
|
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||||
|
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range addressPrefixes {
|
||||||
|
exists, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("address %s should exist in route table", prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newNet, err := stdnet.NewNet()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||||
|
require.NoError(t, err, "should create testing WireGuard interface")
|
||||||
|
|
||||||
|
err = wgInterface.Create()
|
||||||
|
require.NoError(t, err, "should create testing WireGuard interface")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
wgInterface.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return wgInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestEnv(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
setupDummyInterfacesAndRoutes(t)
|
||||||
|
|
||||||
|
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, wgIface.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
_, _, err := setupRouting(nil, wgIface)
|
||||||
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, cleanupRouting())
|
||||||
|
})
|
||||||
|
|
||||||
|
// default route exists in main table and vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 10.0.0.0/8 route exists in main table and vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 10.10.0.0/24 more specific route exists in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 127.0.10.0/24 more specific route exists in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// unique route in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||||
|
t.Helper()
|
||||||
|
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixGateway, _, err := getNextHop(prefix.Addr())
|
||||||
|
require.NoError(t, err, "getNextHop should not return err")
|
||||||
|
if invert {
|
||||||
|
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
234
client/internal/routemanager/systemops_unix_test.go
Normal file
234
client/internal/routemanager/systemops_unix_test.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gopacket/gopacket"
|
||||||
|
"github.com/gopacket/gopacket/layers"
|
||||||
|
"github.com/gopacket/gopacket/pcap"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketExpectation struct {
|
||||||
|
SrcIP net.IP
|
||||||
|
DstIP net.IP
|
||||||
|
SrcPort int
|
||||||
|
DstPort int
|
||||||
|
UDP bool
|
||||||
|
TCP bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
destination string
|
||||||
|
expectedInterface string
|
||||||
|
dialer dialer
|
||||||
|
expectedPacket PacketExpectation
|
||||||
|
}
|
||||||
|
|
||||||
|
var testCases = []testCase{
|
||||||
|
{
|
||||||
|
name: "To external host without custom dialer via vpn",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To external host with custom dialer via physical interface",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To unique vpn route with custom dialer via physical interface",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To unique vpn route without custom dialer via vpn",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouting(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
setupTestEnv(t)
|
||||||
|
|
||||||
|
filter := createBPFFilter(tc.destination)
|
||||||
|
handle := startPacketCapture(t, tc.expectedInterface, filter)
|
||||||
|
|
||||||
|
sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer)
|
||||||
|
|
||||||
|
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
||||||
|
packet, err := packetSource.NextPacket()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifyPacket(t, packet, tc.expectedPacket)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||||
|
return PacketExpectation{
|
||||||
|
SrcIP: net.ParseIP(srcIP),
|
||||||
|
DstIP: net.ParseIP(dstIP),
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
UDP: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
inactive, err := pcap.NewInactiveHandle(intf)
|
||||||
|
require.NoError(t, err, "Failed to create inactive pcap handle")
|
||||||
|
defer inactive.CleanUp()
|
||||||
|
|
||||||
|
err = inactive.SetSnapLen(1600)
|
||||||
|
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
||||||
|
|
||||||
|
err = inactive.SetTimeout(time.Second * 10)
|
||||||
|
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
||||||
|
|
||||||
|
err = inactive.SetImmediateMode(true)
|
||||||
|
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
||||||
|
|
||||||
|
handle, err := inactive.Activate()
|
||||||
|
require.NoError(t, err, "Failed to activate pcap handle")
|
||||||
|
t.Cleanup(handle.Close)
|
||||||
|
|
||||||
|
err = handle.SetBPFFilter(filter)
|
||||||
|
require.NoError(t, err, "Failed to set BPF filter")
|
||||||
|
|
||||||
|
return handle
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if dialer == nil {
|
||||||
|
dialer = &net.Dialer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourcePort != 0 {
|
||||||
|
localUDPAddr := &net.UDPAddr{
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
Port: sourcePort,
|
||||||
|
}
|
||||||
|
switch dialer := dialer.(type) {
|
||||||
|
case *nbnet.Dialer:
|
||||||
|
dialer.LocalAddr = localUDPAddr
|
||||||
|
case *net.Dialer:
|
||||||
|
dialer.LocalAddr = localUDPAddr
|
||||||
|
default:
|
||||||
|
t.Fatal("Unsupported dialer type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.Id = dns.Id()
|
||||||
|
msg.RecursionDesired = true
|
||||||
|
msg.Question = []dns.Question{
|
||||||
|
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.Dial("udp", destination)
|
||||||
|
require.NoError(t, err, "Failed to dial UDP")
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
data, err := msg.Pack()
|
||||||
|
require.NoError(t, err, "Failed to pack DNS message")
|
||||||
|
|
||||||
|
_, err = conn.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "required key not available") {
|
||||||
|
t.Logf("Ignoring WireGuard key error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Failed to send DNS query: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createBPFFilter(destination string) string {
|
||||||
|
host, port, err := net.SplitHostPort(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
|
||||||
|
}
|
||||||
|
return "udp"
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
||||||
|
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
||||||
|
|
||||||
|
ip, ok := ipLayer.(*layers.IPv4)
|
||||||
|
require.True(t, ok, "Failed to cast to IPv4 layer")
|
||||||
|
|
||||||
|
// Convert both source and destination IP addresses to 16-byte representation
|
||||||
|
expectedSrcIP := exp.SrcIP.To16()
|
||||||
|
actualSrcIP := ip.SrcIP.To16()
|
||||||
|
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
||||||
|
|
||||||
|
expectedDstIP := exp.DstIP.To16()
|
||||||
|
actualDstIP := ip.DstIP.To16()
|
||||||
|
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
||||||
|
|
||||||
|
if exp.UDP {
|
||||||
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
|
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
||||||
|
|
||||||
|
udp, ok := udpLayer.(*layers.UDP)
|
||||||
|
require.True(t, ok, "Failed to cast to UDP layer")
|
||||||
|
|
||||||
|
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
||||||
|
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if exp.TCP {
|
||||||
|
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
||||||
|
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
||||||
|
|
||||||
|
tcp, ok := tcpLayer.(*layers.TCP)
|
||||||
|
require.True(t, ok, "Failed to cast to TCP layer")
|
||||||
|
|
||||||
|
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
||||||
|
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,9 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/yusufpapurcu/wmi"
|
"github.com/yusufpapurcu/wmi"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Win32_IP4RouteTable struct {
|
type Win32_IP4RouteTable struct {
|
||||||
@@ -16,6 +21,16 @@ type Win32_IP4RouteTable struct {
|
|||||||
Mask string
|
Mask string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var routeManager *RouteManager
|
||||||
|
|
||||||
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
|
return cleanupRoutingWithRouteManager(routeManager)
|
||||||
|
}
|
||||||
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
var routes []Win32_IP4RouteTable
|
var routes []Win32_IP4RouteTable
|
||||||
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
||||||
@@ -48,10 +63,85 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
|
||||||
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
|
destinationPrefix := prefix.String()
|
||||||
|
psCmd := "New-NetRoute"
|
||||||
|
|
||||||
|
addressFamily := "IPv4"
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
addressFamily = "IPv6"
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
script := fmt.Sprintf(
|
||||||
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
|
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
|
||||||
|
psCmd, addressFamily, destinationPrefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
if intfIdx != "" {
|
||||||
|
script = fmt.Sprintf(
|
||||||
|
`%s -InterfaceIndex %s`, script, intfIdx,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
script = fmt.Sprintf(
|
||||||
|
`%s -InterfaceAlias "%s"`, script, intf,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nexthop.IsValid() {
|
||||||
|
script = fmt.Sprintf(
|
||||||
|
`%s -NextHop "%s"`, script, nexthop,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||||
|
log.Tracef("PowerShell %s: %s", script, string(out))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("PowerShell add route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
||||||
|
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
|
||||||
|
|
||||||
|
out, err := exec.Command("route", args...).CombinedOutput()
|
||||||
|
|
||||||
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route add: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
var intfIdx string
|
||||||
|
if nexthop.Zone() != "" {
|
||||||
|
intfIdx = nexthop.Zone()
|
||||||
|
nexthop.WithZone("")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Powershell doesn't support adding routes without an interface but allows to add interface by name
|
||||||
|
if intf != "" || intfIdx != "" {
|
||||||
|
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
|
||||||
|
}
|
||||||
|
return addRouteCmd(prefix, nexthop, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
||||||
|
args := []string{"delete", prefix.String()}
|
||||||
|
if nexthop.IsValid() {
|
||||||
|
nexthop.WithZone("")
|
||||||
|
args = append(args, nexthop.Unmap().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("route", args...).CombinedOutput()
|
||||||
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("remove route: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
289
client/internal/routemanager/systemops_windows_test.go
Normal file
289
client/internal/routemanager/systemops_windows_test.go
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var expectedExtInt = "Ethernet1"
|
||||||
|
|
||||||
|
type RouteInfo struct {
|
||||||
|
NextHop string `json:"nexthop"`
|
||||||
|
InterfaceAlias string `json:"interfacealias"`
|
||||||
|
RouteMetric int `json:"routemetric"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FindNetRouteOutput struct {
|
||||||
|
IPAddress string `json:"IPAddress"`
|
||||||
|
InterfaceIndex int `json:"InterfaceIndex"`
|
||||||
|
InterfaceAlias string `json:"InterfaceAlias"`
|
||||||
|
AddressFamily int `json:"AddressFamily"`
|
||||||
|
NextHop string `json:"NextHop"`
|
||||||
|
DestinationPrefix string `json:"DestinationPrefix"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
destination string
|
||||||
|
expectedSourceIP string
|
||||||
|
expectedDestPrefix string
|
||||||
|
expectedNextHop string
|
||||||
|
expectedInterface string
|
||||||
|
dialer dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedVPNint = "wgtest0"
|
||||||
|
|
||||||
|
var testCases = []testCase{
|
||||||
|
{
|
||||||
|
name: "To external host without custom dialer via vpn",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedSourceIP: "100.64.0.1",
|
||||||
|
expectedDestPrefix: "128.0.0.0/1",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "wgtest0",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To external host with custom dialer via physical interface",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedDestPrefix: "192.0.2.1/32",
|
||||||
|
expectedInterface: expectedExtInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedDestPrefix: "10.0.0.2/32",
|
||||||
|
expectedInterface: expectedExtInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedSourceIP: "10.0.0.1",
|
||||||
|
expectedDestPrefix: "10.0.0.0/8",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To unique vpn route with custom dialer via physical interface",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedDestPrefix: "172.16.0.2/32",
|
||||||
|
expectedInterface: expectedExtInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To unique vpn route without custom dialer via vpn",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedSourceIP: "100.64.0.1",
|
||||||
|
expectedDestPrefix: "172.16.0.0/12",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "wgtest0",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via vpn interface",
|
||||||
|
destination: "10.10.0.2:53",
|
||||||
|
expectedSourceIP: "100.64.0.1",
|
||||||
|
expectedDestPrefix: "10.10.0.0/24",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "wgtest0",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To more specific route (local) without custom dialer via physical interface",
|
||||||
|
destination: "127.0.10.2:53",
|
||||||
|
expectedSourceIP: "10.0.0.1",
|
||||||
|
expectedDestPrefix: "127.0.0.0/8",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouting(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
setupTestEnv(t)
|
||||||
|
|
||||||
|
route, err := fetchOriginalGateway()
|
||||||
|
require.NoError(t, err, "Failed to fetch original gateway")
|
||||||
|
ip, err := fetchInterfaceIP(route.InterfaceAlias)
|
||||||
|
require.NoError(t, err, "Failed to fetch interface IP")
|
||||||
|
|
||||||
|
output := testRoute(t, tc.destination, tc.dialer)
|
||||||
|
if tc.expectedInterface == expectedExtInt {
|
||||||
|
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
||||||
|
} else {
|
||||||
|
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchInterfaceIP fetches the IPv4 address of the specified interface.
|
||||||
|
func fetchInterfaceIP(interfaceAlias string) (string, error) {
|
||||||
|
script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
|
||||||
|
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := strings.TrimSpace(string(out))
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", destination)
|
||||||
|
require.NoError(t, err, "Failed to dial destination")
|
||||||
|
defer func() {
|
||||||
|
err := conn.Close()
|
||||||
|
assert.NoError(t, err, "Failed to close connection")
|
||||||
|
}()
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(destination)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
||||||
|
|
||||||
|
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||||
|
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
||||||
|
|
||||||
|
var outputs []FindNetRouteOutput
|
||||||
|
err = json.Unmarshal(out, &outputs)
|
||||||
|
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
|
||||||
|
|
||||||
|
require.Greater(t, len(outputs), 0, "No route found for destination")
|
||||||
|
combinedOutput := combineOutputs(outputs)
|
||||||
|
|
||||||
|
return combinedOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
|
||||||
|
require.NoError(t, err)
|
||||||
|
subnetMaskSize, _ := ipNet.Mask.Size()
|
||||||
|
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
|
||||||
|
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
|
||||||
|
|
||||||
|
// Wait for the IP address to be applied
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err = waitForIPAddress(ctx, interfaceName, ip.String())
|
||||||
|
require.NoError(t, err, "IP address not applied within timeout")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
|
||||||
|
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
|
||||||
|
})
|
||||||
|
|
||||||
|
return interfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOriginalGateway() (*RouteInfo, error) {
|
||||||
|
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var routeInfo RouteInfo
|
||||||
|
err = json.Unmarshal(output, &routeInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &routeInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
|
||||||
|
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
|
||||||
|
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
|
||||||
|
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||||
|
for _, ip := range ipAddresses {
|
||||||
|
if strings.TrimSpace(ip) == expectedIPAddress {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
|
||||||
|
var combined FindNetRouteOutput
|
||||||
|
|
||||||
|
for _, output := range outputs {
|
||||||
|
if output.IPAddress != "" {
|
||||||
|
combined.IPAddress = output.IPAddress
|
||||||
|
}
|
||||||
|
if output.InterfaceIndex != 0 {
|
||||||
|
combined.InterfaceIndex = output.InterfaceIndex
|
||||||
|
}
|
||||||
|
if output.InterfaceAlias != "" {
|
||||||
|
combined.InterfaceAlias = output.InterfaceAlias
|
||||||
|
}
|
||||||
|
if output.AddressFamily != 0 {
|
||||||
|
combined.AddressFamily = output.AddressFamily
|
||||||
|
}
|
||||||
|
if output.NextHop != "" {
|
||||||
|
combined.NextHop = output.NextHop
|
||||||
|
}
|
||||||
|
if output.DestinationPrefix != "" {
|
||||||
|
combined.DestinationPrefix = output.DestinationPrefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &combined
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
|
||||||
|
}
|
||||||
@@ -1,10 +1,8 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,7 +23,7 @@ func (pl portLookup) searchFreePort() (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pl portLookup) tryToBind(port int) error {
|
func (pl portLookup) tryToBind(port int) error {
|
||||||
l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port))
|
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
@@ -29,7 +30,7 @@ type WGEBPFProxy struct {
|
|||||||
turnConnMutex sync.Mutex
|
turnConnMutex sync.Mutex
|
||||||
|
|
||||||
rawConn net.PacketConn
|
rawConn net.PacketConn
|
||||||
conn *net.UDPConn
|
conn transport.UDPConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||||
@@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
IP: net.ParseIP("127.0.0.1"),
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
}
|
}
|
||||||
|
|
||||||
p.conn, err = nbnet.ListenUDP("udp", &addr)
|
conn, err := nbnet.ListenUDP("udp", &addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cErr := p.Free()
|
cErr := p.Free()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
@@ -75,6 +76,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
p.conn = conn
|
||||||
|
|
||||||
go p.proxyToRemote()
|
go p.proxyToRemote()
|
||||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||||
@@ -228,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set the fwmark on the socket.
|
// Set the fwmark on the socket.
|
||||||
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
|
err = nbnet.SetSocketOpt(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
|||||||
BytesRx: peerState.BytesRx,
|
BytesRx: peerState.BytesRx,
|
||||||
BytesTx: peerState.BytesTx,
|
BytesTx: peerState.BytesTx,
|
||||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||||
Routes: maps.Keys(peerState.Routes),
|
Routes: maps.Keys(peerState.GetRoutes()),
|
||||||
Latency: durationpb.New(peerState.Latency),
|
Latency: durationpb.New(peerState.Latency),
|
||||||
}
|
}
|
||||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||||
|
|||||||
@@ -25,8 +25,6 @@ func Detect(ctx context.Context) string {
|
|||||||
detectDigitalOcean,
|
detectDigitalOcean,
|
||||||
detectGCP,
|
detectGCP,
|
||||||
detectOracle,
|
detectOracle,
|
||||||
detectIBMCloud,
|
|
||||||
detectSoftlayer,
|
|
||||||
detectVultr,
|
detectVultr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func detectGCP(ctx context.Context) string {
|
func detectGCP(ctx context.Context) string {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
package detect_cloud
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
func detectIBMCloud(ctx context.Context) string {
|
|
||||||
v1ResultChan := make(chan bool, 1)
|
|
||||||
v2ResultChan := make(chan bool, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
v1ResultChan <- detectIBMSecure(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
v2ResultChan <- detectIBM(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
v1Result, v2Result := <-v1ResultChan, <-v2ResultChan
|
|
||||||
|
|
||||||
if v1Result || v2Result {
|
|
||||||
return "IBM Cloud"
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func detectIBMSecure(ctx context.Context) bool {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := hc.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
return resp.StatusCode == http.StatusOK
|
|
||||||
}
|
|
||||||
|
|
||||||
func detectIBM(ctx context.Context) bool {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := hc.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
return resp.StatusCode == http.StatusOK
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
package detect_cloud
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
func detectSoftlayer(ctx context.Context) string {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := hc.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
// Since SoftLayer was acquired by IBM, we should return "IBM Cloud"
|
|
||||||
return "IBM Cloud"
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
4
go.mod
4
go.mod
@@ -53,14 +53,14 @@ require (
|
|||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
github.com/hashicorp/go-version v1.6.0
|
github.com/hashicorp/go-version v1.6.0
|
||||||
github.com/libp2p/go-netroute v0.2.0
|
github.com/libp2p/go-netroute v0.2.1
|
||||||
github.com/magiconair/properties v1.8.5
|
github.com/magiconair/properties v1.8.5
|
||||||
github.com/mattn/go-sqlite3 v1.14.19
|
github.com/mattn/go-sqlite3 v1.14.19
|
||||||
github.com/mdlayher/socket v0.4.1
|
github.com/mdlayher/socket v0.4.1
|
||||||
github.com/miekg/dns v1.1.43
|
github.com/miekg/dns v1.1.43
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
|
|||||||
10
go.sum
10
go.sum
@@ -345,8 +345,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||||
github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE=
|
github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU=
|
||||||
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
|
github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ=
|
||||||
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
|
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
|
||||||
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
|
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
|
||||||
github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
|
github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
|
||||||
@@ -383,8 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
|||||||
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
||||||
@@ -659,7 +659,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
|
|||||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||||
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
|
||||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||||
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
@@ -746,7 +745,6 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgKernelConfigurer struct {
|
type wgKernelConfigurer struct {
|
||||||
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fwmark := nbnet.NetbirdFwmark
|
fwmark := getFwmark()
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
PrivateKey: &key,
|
PrivateKey: &key,
|
||||||
ReplacePeers: true,
|
ReplacePeers: true,
|
||||||
|
|||||||
@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.NetbirdFwmark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ services:
|
|||||||
command: [
|
command: [
|
||||||
"--port", "443",
|
"--port", "443",
|
||||||
"--log-file", "console",
|
"--log-file", "console",
|
||||||
|
"--log-level", "info",
|
||||||
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
||||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ version: "3"
|
|||||||
services:
|
services:
|
||||||
#UI dashboard
|
#UI dashboard
|
||||||
dashboard:
|
dashboard:
|
||||||
image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG
|
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
#ports:
|
#ports:
|
||||||
# - 80:80
|
# - 80:80
|
||||||
|
|||||||
@@ -251,7 +251,7 @@ var (
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
|
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
|||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership
|
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
||||||
var filteredRoutes []*route.Route
|
var filteredRoutes []*route.Route
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
|
|||||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
|
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Id != account.CreatedBy {
|
if user.Role != UserRoleOwner {
|
||||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||||
}
|
}
|
||||||
for _, otherUser := range account.Users {
|
for _, otherUser := range account.Users {
|
||||||
@@ -1473,7 +1473,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
|
|||||||
// if domain already has a primary account, add regular user
|
// if domain already has a primary account, add regular user
|
||||||
if domainAcc != nil {
|
if domainAcc != nil {
|
||||||
account = domainAcc
|
account = domainAcc
|
||||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1849,6 +1849,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
||||||
|
log.Debugf("validated peers has been invalidated for account %s", accountID)
|
||||||
updatedAccount, err := am.Store.GetAccount(accountID)
|
updatedAccount, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get account %s: %v", accountID, err)
|
log.Errorf("failed to get account %s: %v", accountID, err)
|
||||||
@@ -1864,6 +1865,7 @@ func addAllGroup(account *Account) error {
|
|||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: "All",
|
Name: "All",
|
||||||
Issued: nbgroup.GroupIssuedAPI,
|
Issued: nbgroup.GroupIssuedAPI,
|
||||||
|
AccountID: account.Id,
|
||||||
}
|
}
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||||
@@ -1907,7 +1909,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
|||||||
routes := make(map[string]*route.Route)
|
routes := make(map[string]*route.Route)
|
||||||
setupKeys := map[string]*SetupKey{}
|
setupKeys := map[string]*SetupKey{}
|
||||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||||
users[userID] = NewOwnerUser(userID)
|
users[userID] = NewOwnerUser(userID, accountID)
|
||||||
dnsSettings := DNSSettings{
|
dnsSettings := DNSSettings{
|
||||||
DisabledManagementGroups: make([]string, 0),
|
DisabledManagementGroups: make([]string, 0),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,133 +11,134 @@ type Code struct {
|
|||||||
Code string
|
Code string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Existing consts must not be changed, as this will break the compatibility with the existing data
|
||||||
const (
|
const (
|
||||||
// PeerAddedByUser indicates that a user added a new peer to the system
|
// PeerAddedByUser indicates that a user added a new peer to the system
|
||||||
PeerAddedByUser Activity = iota
|
PeerAddedByUser Activity = 0
|
||||||
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
|
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
|
||||||
PeerAddedWithSetupKey
|
PeerAddedWithSetupKey Activity = 1
|
||||||
// UserJoined indicates that a new user joined the account
|
// UserJoined indicates that a new user joined the account
|
||||||
UserJoined
|
UserJoined Activity = 2
|
||||||
// UserInvited indicates that a new user was invited to join the account
|
// UserInvited indicates that a new user was invited to join the account
|
||||||
UserInvited
|
UserInvited Activity = 3
|
||||||
// AccountCreated indicates that a new account has been created
|
// AccountCreated indicates that a new account has been created
|
||||||
AccountCreated
|
AccountCreated Activity = 4
|
||||||
// PeerRemovedByUser indicates that a user removed a peer from the system
|
// PeerRemovedByUser indicates that a user removed a peer from the system
|
||||||
PeerRemovedByUser
|
PeerRemovedByUser Activity = 5
|
||||||
// RuleAdded indicates that a user added a new rule
|
// RuleAdded indicates that a user added a new rule
|
||||||
RuleAdded
|
RuleAdded Activity = 6
|
||||||
// RuleUpdated indicates that a user updated a rule
|
// RuleUpdated indicates that a user updated a rule
|
||||||
RuleUpdated
|
RuleUpdated Activity = 7
|
||||||
// RuleRemoved indicates that a user removed a rule
|
// RuleRemoved indicates that a user removed a rule
|
||||||
RuleRemoved
|
RuleRemoved Activity = 8
|
||||||
// PolicyAdded indicates that a user added a new policy
|
// PolicyAdded indicates that a user added a new policy
|
||||||
PolicyAdded
|
PolicyAdded Activity = 9
|
||||||
// PolicyUpdated indicates that a user updated a policy
|
// PolicyUpdated indicates that a user updated a policy
|
||||||
PolicyUpdated
|
PolicyUpdated Activity = 10
|
||||||
// PolicyRemoved indicates that a user removed a policy
|
// PolicyRemoved indicates that a user removed a policy
|
||||||
PolicyRemoved
|
PolicyRemoved Activity = 11
|
||||||
// SetupKeyCreated indicates that a user created a new setup key
|
// SetupKeyCreated indicates that a user created a new setup key
|
||||||
SetupKeyCreated
|
SetupKeyCreated Activity = 12
|
||||||
// SetupKeyUpdated indicates that a user updated a setup key
|
// SetupKeyUpdated indicates that a user updated a setup key
|
||||||
SetupKeyUpdated
|
SetupKeyUpdated Activity = 13
|
||||||
// SetupKeyRevoked indicates that a user revoked a setup key
|
// SetupKeyRevoked indicates that a user revoked a setup key
|
||||||
SetupKeyRevoked
|
SetupKeyRevoked Activity = 14
|
||||||
// SetupKeyOverused indicates that setup key usage exhausted
|
// SetupKeyOverused indicates that setup key usage exhausted
|
||||||
SetupKeyOverused
|
SetupKeyOverused Activity = 15
|
||||||
// GroupCreated indicates that a user created a group
|
// GroupCreated indicates that a user created a group
|
||||||
GroupCreated
|
GroupCreated Activity = 16
|
||||||
// GroupUpdated indicates that a user updated a group
|
// GroupUpdated indicates that a user updated a group
|
||||||
GroupUpdated
|
GroupUpdated Activity = 17
|
||||||
// GroupAddedToPeer indicates that a user added group to a peer
|
// GroupAddedToPeer indicates that a user added group to a peer
|
||||||
GroupAddedToPeer
|
GroupAddedToPeer Activity = 18
|
||||||
// GroupRemovedFromPeer indicates that a user removed peer group
|
// GroupRemovedFromPeer indicates that a user removed peer group
|
||||||
GroupRemovedFromPeer
|
GroupRemovedFromPeer Activity = 19
|
||||||
// GroupAddedToUser indicates that a user added group to a user
|
// GroupAddedToUser indicates that a user added group to a user
|
||||||
GroupAddedToUser
|
GroupAddedToUser Activity = 20
|
||||||
// GroupRemovedFromUser indicates that a user removed a group from a user
|
// GroupRemovedFromUser indicates that a user removed a group from a user
|
||||||
GroupRemovedFromUser
|
GroupRemovedFromUser Activity = 21
|
||||||
// UserRoleUpdated indicates that a user changed the role of a user
|
// UserRoleUpdated indicates that a user changed the role of a user
|
||||||
UserRoleUpdated
|
UserRoleUpdated Activity = 22
|
||||||
// GroupAddedToSetupKey indicates that a user added group to a setup key
|
// GroupAddedToSetupKey indicates that a user added group to a setup key
|
||||||
GroupAddedToSetupKey
|
GroupAddedToSetupKey Activity = 23
|
||||||
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
|
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
|
||||||
GroupRemovedFromSetupKey
|
GroupRemovedFromSetupKey Activity = 24
|
||||||
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
|
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
|
||||||
GroupAddedToDisabledManagementGroups
|
GroupAddedToDisabledManagementGroups Activity = 25
|
||||||
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
|
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
|
||||||
GroupRemovedFromDisabledManagementGroups
|
GroupRemovedFromDisabledManagementGroups Activity = 26
|
||||||
// RouteCreated indicates that a user created a route
|
// RouteCreated indicates that a user created a route
|
||||||
RouteCreated
|
RouteCreated Activity = 27
|
||||||
// RouteRemoved indicates that a user deleted a route
|
// RouteRemoved indicates that a user deleted a route
|
||||||
RouteRemoved
|
RouteRemoved Activity = 28
|
||||||
// RouteUpdated indicates that a user updated a route
|
// RouteUpdated indicates that a user updated a route
|
||||||
RouteUpdated
|
RouteUpdated Activity = 29
|
||||||
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
|
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
|
||||||
PeerSSHEnabled
|
PeerSSHEnabled Activity = 30
|
||||||
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
|
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
|
||||||
PeerSSHDisabled
|
PeerSSHDisabled Activity = 31
|
||||||
// PeerRenamed indicates that a user renamed a peer
|
// PeerRenamed indicates that a user renamed a peer
|
||||||
PeerRenamed
|
PeerRenamed Activity = 32
|
||||||
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
|
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
|
||||||
PeerLoginExpirationEnabled
|
PeerLoginExpirationEnabled Activity = 33
|
||||||
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
|
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
|
||||||
PeerLoginExpirationDisabled
|
PeerLoginExpirationDisabled Activity = 34
|
||||||
// NameserverGroupCreated indicates that a user created a nameservers group
|
// NameserverGroupCreated indicates that a user created a nameservers group
|
||||||
NameserverGroupCreated
|
NameserverGroupCreated Activity = 35
|
||||||
// NameserverGroupDeleted indicates that a user deleted a nameservers group
|
// NameserverGroupDeleted indicates that a user deleted a nameservers group
|
||||||
NameserverGroupDeleted
|
NameserverGroupDeleted Activity = 36
|
||||||
// NameserverGroupUpdated indicates that a user updated a nameservers group
|
// NameserverGroupUpdated indicates that a user updated a nameservers group
|
||||||
NameserverGroupUpdated
|
NameserverGroupUpdated Activity = 37
|
||||||
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
|
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
|
||||||
AccountPeerLoginExpirationEnabled
|
AccountPeerLoginExpirationEnabled Activity = 38
|
||||||
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
|
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
|
||||||
AccountPeerLoginExpirationDisabled
|
AccountPeerLoginExpirationDisabled Activity = 39
|
||||||
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
|
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
|
||||||
AccountPeerLoginExpirationDurationUpdated
|
AccountPeerLoginExpirationDurationUpdated Activity = 40
|
||||||
// PersonalAccessTokenCreated indicates that a user created a personal access token
|
// PersonalAccessTokenCreated indicates that a user created a personal access token
|
||||||
PersonalAccessTokenCreated
|
PersonalAccessTokenCreated Activity = 41
|
||||||
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
||||||
PersonalAccessTokenDeleted
|
PersonalAccessTokenDeleted Activity = 42
|
||||||
// ServiceUserCreated indicates that a user created a service user
|
// ServiceUserCreated indicates that a user created a service user
|
||||||
ServiceUserCreated
|
ServiceUserCreated Activity = 43
|
||||||
// ServiceUserDeleted indicates that a user deleted a service user
|
// ServiceUserDeleted indicates that a user deleted a service user
|
||||||
ServiceUserDeleted
|
ServiceUserDeleted Activity = 44
|
||||||
// UserBlocked indicates that a user blocked another user
|
// UserBlocked indicates that a user blocked another user
|
||||||
UserBlocked
|
UserBlocked Activity = 45
|
||||||
// UserUnblocked indicates that a user unblocked another user
|
// UserUnblocked indicates that a user unblocked another user
|
||||||
UserUnblocked
|
UserUnblocked Activity = 46
|
||||||
// UserDeleted indicates that a user deleted another user
|
// UserDeleted indicates that a user deleted another user
|
||||||
UserDeleted
|
UserDeleted Activity = 47
|
||||||
// GroupDeleted indicates that a user deleted group
|
// GroupDeleted indicates that a user deleted group
|
||||||
GroupDeleted
|
GroupDeleted Activity = 48
|
||||||
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
||||||
UserLoggedInPeer
|
UserLoggedInPeer Activity = 49
|
||||||
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
|
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
|
||||||
PeerLoginExpired
|
PeerLoginExpired Activity = 50
|
||||||
// DashboardLogin indicates that the user logged in to the dashboard
|
// DashboardLogin indicates that the user logged in to the dashboard
|
||||||
DashboardLogin
|
DashboardLogin Activity = 51
|
||||||
// IntegrationCreated indicates that the user created an integration
|
// IntegrationCreated indicates that the user created an integration
|
||||||
IntegrationCreated
|
IntegrationCreated Activity = 52
|
||||||
// IntegrationUpdated indicates that the user updated an integration
|
// IntegrationUpdated indicates that the user updated an integration
|
||||||
IntegrationUpdated
|
IntegrationUpdated Activity = 53
|
||||||
// IntegrationDeleted indicates that the user deleted an integration
|
// IntegrationDeleted indicates that the user deleted an integration
|
||||||
IntegrationDeleted
|
IntegrationDeleted Activity = 54
|
||||||
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
|
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
|
||||||
AccountPeerApprovalEnabled
|
AccountPeerApprovalEnabled Activity = 55
|
||||||
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
|
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
|
||||||
AccountPeerApprovalDisabled
|
AccountPeerApprovalDisabled Activity = 56
|
||||||
// PeerApproved indicates that the peer has been approved
|
// PeerApproved indicates that the peer has been approved
|
||||||
PeerApproved
|
PeerApproved Activity = 57
|
||||||
// PeerApprovalRevoked indicates that the peer approval has been revoked
|
// PeerApprovalRevoked indicates that the peer approval has been revoked
|
||||||
PeerApprovalRevoked
|
PeerApprovalRevoked Activity = 58
|
||||||
// TransferredOwnerRole indicates that the user transferred the owner role of the account
|
// TransferredOwnerRole indicates that the user transferred the owner role of the account
|
||||||
TransferredOwnerRole
|
TransferredOwnerRole Activity = 59
|
||||||
// PostureCheckCreated indicates that the user created a posture check
|
// PostureCheckCreated indicates that the user created a posture check
|
||||||
PostureCheckCreated
|
PostureCheckCreated Activity = 60
|
||||||
// PostureCheckUpdated indicates that the user updated a posture check
|
// PostureCheckUpdated indicates that the user updated a posture check
|
||||||
PostureCheckUpdated
|
PostureCheckUpdated Activity = 61
|
||||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||||
PostureCheckDeleted
|
PostureCheckDeleted Activity = 62
|
||||||
)
|
)
|
||||||
|
|
||||||
var activityMap = map[Activity]Code{
|
var activityMap = map[Activity]Code{
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
|
|||||||
|
|
||||||
func TestAccounts_AccountsHandler(t *testing.T) {
|
func TestAccounts_AccountsHandler(t *testing.T) {
|
||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
|
|
||||||
sr := func(v string) *string { return &v }
|
sr := func(v string) *string { return &v }
|
||||||
br := func(v bool) *bool { return &v }
|
br := func(v bool) *bool { return &v }
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
|
|||||||
Id: testDNSSettingsAccountID,
|
Id: testDNSSettingsAccountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
testDNSSettingsUserID: server.NewAdminUser("test_user"),
|
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
|
||||||
},
|
},
|
||||||
DNSSettings: baseExistingDNSSettings,
|
DNSSettings: baseExistingDNSSettings,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
events := generateEvents(accountID, adminUser.Id)
|
events := generateEvents(accountID, adminUser.Id)
|
||||||
handler := initEventsTestData(accountID, adminUser, events...)
|
handler := initEventsTestData(accountID, adminUser, events...)
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
|||||||
return &GeolocationsHandler{
|
return &GeolocationsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
|
|||||||
Name: "Group",
|
Name: "Group",
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
p := initGroupTestData(adminUser, group)
|
p := initGroupTestData(adminUser, group)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
p := initGroupTestData(adminUser)
|
p := initGroupTestData(adminUser)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
p := initGroupTestData(adminUser)
|
p := initGroupTestData(adminUser)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
s "github.com/netbirdio/netbird/management/server"
|
s "github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
@@ -38,7 +39,7 @@ type emptyObject struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||||
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
|||||||
AuthCfg: authCfg,
|
AuthCfg: authCfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
|
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
|
||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
|
|||||||
Id: testNSGroupAccountID,
|
Id: testNSGroupAccountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
"test_user": server.NewAdminUser("test_user", "account_id"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
return "netbird.selfhosted"
|
return "netbird.selfhosted"
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
|||||||
return accountPostureChecks, nil
|
return accountPostureChecks, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
user := server.NewAdminUser("test_user", "account_id")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
"test_user": server.NewAdminUser("test_user", "account_id"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
defaultSetupKey := server.GenerateDefaultSetupKey()
|
defaultSetupKey := server.GenerateDefaultSetupKey()
|
||||||
defaultSetupKey.Id = existingSetupKeyID
|
defaultSetupKey.Id = existingSetupKeyID
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||||
|
|
||||||
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
||||||
server.SetupKeyUnlimitedUsage, true)
|
server.SetupKeyUnlimitedUsage, true)
|
||||||
|
|||||||
@@ -115,7 +115,15 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
|||||||
data.Set("client_id", ac.clientConfig.ClientID)
|
data.Set("client_id", ac.clientConfig.ClientID)
|
||||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||||
data.Set("scope", "https://graph.microsoft.com/.default")
|
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// get base url and add "/.default" as scope
|
||||||
|
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||||
|
scopeURL := baseURL + "/.default"
|
||||||
|
data.Set("scope", scopeURL)
|
||||||
|
|
||||||
payload := strings.NewReader(data.Encode())
|
payload := strings.NewReader(data.Encode())
|
||||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseOktaUserToUserData parse okta user to UserData.
|
// parseOktaUser parse okta user to UserData.
|
||||||
func parseOktaUser(user *okta.User) (*UserData, error) {
|
func parseOktaUser(user *okta.User) (*UserData, error) {
|
||||||
var oktaUser struct {
|
var oktaUser struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
|||||||
@@ -706,7 +706,7 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||||
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
|
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
|
||||||
if am.UpdateIntegratedValidatorGroupsFunc != nil {
|
if am.UpdateIntegratedValidatorGroupsFunc != nil {
|
||||||
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)
|
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)
|
||||||
|
|||||||
@@ -551,8 +551,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
|
|||||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||||
}
|
}
|
||||||
|
|
||||||
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||||
if requiresApproval {
|
if peerNotValid {
|
||||||
emptyMap := &NetworkMap{
|
emptyMap := &NetworkMap{
|
||||||
Network: account.Network.Copy(),
|
Network: account.Network.Copy(),
|
||||||
}
|
}
|
||||||
@@ -563,11 +563,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
|
|||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
}
|
}
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
validPeersMap, err := am.GetValidatedPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
|
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginPeer logs in or registers a peer.
|
// LoginPeer logs in or registers a peer.
|
||||||
|
|||||||
@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
select {
|
select {
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.Debugf("scheduled job %s was canceled, stop timer", ID)
|
log.Tracef("scheduled job %s was canceled, stop timer", ID)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
log.Debugf("time to do a scheduled job %s", ID)
|
log.Tracef("time to do a scheduled job %s", ID)
|
||||||
}
|
}
|
||||||
runIn, reschedule := job()
|
runIn, reschedule := job()
|
||||||
if !reschedule {
|
if !reschedule {
|
||||||
wm.mu.Lock()
|
wm.mu.Lock()
|
||||||
defer wm.mu.Unlock()
|
defer wm.mu.Unlock()
|
||||||
delete(wm.jobs, ID)
|
delete(wm.jobs, ID)
|
||||||
log.Debugf("job %s is not scheduled to run again", ID)
|
log.Tracef("job %s is not scheduled to run again", ID)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
|
|||||||
ticker.Reset(runIn)
|
ticker.Reset(runIn)
|
||||||
}
|
}
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.Debugf("job %s was canceled, stopping timer", ID)
|
log.Tracef("job %s was canceled, stopping timer", ID)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
|
|||||||
return unlock
|
return unlock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
|
||||||
|
// Get the reflect.Value of the records slice
|
||||||
|
v := reflect.ValueOf(records)
|
||||||
|
if v.Kind() != reflect.Slice {
|
||||||
|
return fmt.Errorf("provided input is not a slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert records in batches
|
||||||
|
for i := 0; i < v.Len(); i += batchSize {
|
||||||
|
end := i + batchSize
|
||||||
|
if end > v.Len() {
|
||||||
|
end = v.Len()
|
||||||
|
}
|
||||||
|
// Use reflect.Slice to get a slice of the records for the current batch
|
||||||
|
batch := v.Slice(i, end).Interface()
|
||||||
|
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqliteStore) SaveAccount(account *Account) error {
|
func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
for _, key := range account.SetupKeys {
|
// operate over a fresh copy as we will modify its fields
|
||||||
account.SetupKeysG = append(account.SetupKeysG, *key)
|
accCopy := account.Copy()
|
||||||
|
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
|
||||||
|
for _, key := range accCopy.SetupKeys {
|
||||||
|
//we need an explicit reference to the account for gorm
|
||||||
|
key.AccountID = accCopy.Id
|
||||||
|
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, peer := range account.Peers {
|
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
|
||||||
|
for id, peer := range accCopy.Peers {
|
||||||
peer.ID = id
|
peer.ID = id
|
||||||
account.PeersG = append(account.PeersG, *peer)
|
//we need an explicit reference to the account for gorm
|
||||||
|
peer.AccountID = accCopy.Id
|
||||||
|
accCopy.PeersG = append(accCopy.PeersG, *peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, user := range account.Users {
|
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
|
||||||
|
for id, user := range accCopy.Users {
|
||||||
user.Id = id
|
user.Id = id
|
||||||
|
//we need an explicit reference to the account for gorm
|
||||||
|
user.AccountID = accCopy.Id
|
||||||
|
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
|
||||||
for id, pat := range user.PATs {
|
for id, pat := range user.PATs {
|
||||||
pat.ID = id
|
pat.ID = id
|
||||||
user.PATsG = append(user.PATsG, *pat)
|
user.PATsG = append(user.PATsG, *pat)
|
||||||
}
|
}
|
||||||
account.UsersG = append(account.UsersG, *user)
|
accCopy.UsersG = append(accCopy.UsersG, *user)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, group := range account.Groups {
|
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
|
||||||
|
for id, group := range accCopy.Groups {
|
||||||
group.ID = id
|
group.ID = id
|
||||||
account.GroupsG = append(account.GroupsG, *group)
|
//we need an explicit reference to the account for gorm
|
||||||
|
group.AccountID = accCopy.Id
|
||||||
|
accCopy.GroupsG = append(accCopy.GroupsG, *group)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, route := range account.Routes {
|
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
|
||||||
|
for id, route := range accCopy.Routes {
|
||||||
route.ID = id
|
route.ID = id
|
||||||
account.RoutesG = append(account.RoutesG, *route)
|
//we need an explicit reference to the account for gorm
|
||||||
|
route.AccountID = accCopy.Id
|
||||||
|
accCopy.RoutesG = append(accCopy.RoutesG, *route)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, ns := range account.NameServerGroups {
|
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
|
||||||
|
for id, ns := range accCopy.NameServerGroups {
|
||||||
ns.ID = id
|
ns.ID = id
|
||||||
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
|
//we need an explicit reference to the account for gorm
|
||||||
|
ns.AccountID = accCopy.Id
|
||||||
|
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tx.Select(clause.Associations).Delete(account)
|
result = tx.Select(clause.Associations).Delete(accCopy)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tx.
|
result = tx.
|
||||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
|
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||||
|
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
|
||||||
|
Create(accCopy)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
const batchSize = 500
|
||||||
|
err := batchInsert(accCopy.PeersG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.UsersG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.GroupsG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.RoutesG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
|
||||||
})
|
})
|
||||||
|
|
||||||
took := time.Since(start)
|
took := time.Since(start)
|
||||||
if s.metrics != nil {
|
if s.metrics != nil {
|
||||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||||
}
|
}
|
||||||
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
|
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
|
|||||||
func (s *SqliteStore) DeleteAccount(account *Account) error {
|
func (s *SqliteStore) DeleteAccount(account *Account) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
|
account.UsersG = make([]User, 0, len(account.Users))
|
||||||
|
for id, user := range account.Users {
|
||||||
|
user.Id = id
|
||||||
|
//we need an explicit reference to an account as it is missing for some reason
|
||||||
|
user.AccountID = account.Id
|
||||||
|
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
|
||||||
|
for id, pat := range user.PATs {
|
||||||
|
pat.ID = id
|
||||||
|
user.PATsG = append(user.PATsG, *pat)
|
||||||
|
}
|
||||||
|
account.UsersG = append(account.UsersG, *user)
|
||||||
|
}
|
||||||
|
|
||||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
|
|||||||
@@ -2,7 +2,12 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
route2 "github.com/netbirdio/netbird/route"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
|
|||||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStore(t)
|
||||||
|
|
||||||
|
account := newAccountWithId("account_id", "testuser", "")
|
||||||
|
groupALL, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
setupKey := GenerateDefaultSetupKey()
|
||||||
|
account.SetupKeys[setupKey.Key] = setupKey
|
||||||
|
const numPerAccount = 2000
|
||||||
|
for n := 0; n < numPerAccount; n++ {
|
||||||
|
netIP := randomIPv4()
|
||||||
|
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
|
||||||
|
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: peerID,
|
||||||
|
Key: peerID,
|
||||||
|
SetupKey: "",
|
||||||
|
IP: netIP,
|
||||||
|
Name: peerID,
|
||||||
|
DNSLabel: peerID,
|
||||||
|
UserID: userID,
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
}
|
||||||
|
account.Peers[peerID] = peer
|
||||||
|
group, _ := account.GetGroupAll()
|
||||||
|
group.Peers = append(group.Peers, peerID)
|
||||||
|
user := &User{
|
||||||
|
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
|
||||||
|
AccountID: account.Id,
|
||||||
|
}
|
||||||
|
account.Users[user.Id] = user
|
||||||
|
route := &route2.Route{
|
||||||
|
ID: fmt.Sprintf("network-id-%d", n),
|
||||||
|
Description: "base route",
|
||||||
|
NetID: fmt.Sprintf("network-id-%d", n),
|
||||||
|
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
||||||
|
NetworkType: route2.IPv4Network,
|
||||||
|
Metric: 9999,
|
||||||
|
Masquerade: false,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{groupALL.ID},
|
||||||
|
}
|
||||||
|
account.Routes[route.ID] = route
|
||||||
|
|
||||||
|
group = &nbgroup.Group{
|
||||||
|
ID: fmt.Sprintf("group-id-%d", n),
|
||||||
|
AccountID: account.Id,
|
||||||
|
Name: fmt.Sprintf("group-id-%d", n),
|
||||||
|
Issued: "api",
|
||||||
|
Peers: nil,
|
||||||
|
}
|
||||||
|
account.Groups[group.ID] = group
|
||||||
|
|
||||||
|
nameserver := &nbdns.NameServerGroup{
|
||||||
|
ID: fmt.Sprintf("nameserver-id-%d", n),
|
||||||
|
AccountID: account.Id,
|
||||||
|
Name: fmt.Sprintf("nameserver-id-%d", n),
|
||||||
|
Description: "",
|
||||||
|
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
|
||||||
|
Groups: []string{group.ID},
|
||||||
|
Primary: false,
|
||||||
|
Domains: nil,
|
||||||
|
Enabled: false,
|
||||||
|
SearchDomainsEnabled: false,
|
||||||
|
}
|
||||||
|
account.NameServerGroups[nameserver.ID] = nameserver
|
||||||
|
|
||||||
|
setupKey := GenerateDefaultSetupKey()
|
||||||
|
account.SetupKeys[setupKey.Key] = setupKey
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.SaveAccount(account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if len(store.GetAllAccounts()) != 1 {
|
||||||
|
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := store.GetAccount(account.Id)
|
||||||
|
if a == nil {
|
||||||
|
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Policies) != 1 {
|
||||||
|
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Policies[0].Rules) != 1 {
|
||||||
|
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Peers) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.Peers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Users) != numPerAccount+1 {
|
||||||
|
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount+1, len(a.Users))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.Routes) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.Routes))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.NameServerGroups) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.NameServerGroups))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.NameServerGroups) != numPerAccount {
|
||||||
|
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount, len(a.NameServerGroups))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
|
||||||
|
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
|
||||||
|
numPerAccount+1, len(a.SetupKeys))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqlite_SaveAccount(t *testing.T) {
|
func TestSqlite_SaveAccount(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
|||||||
Name: "peer name",
|
Name: "peer name",
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||||
}
|
}
|
||||||
|
admin := account.Users["testuser"]
|
||||||
|
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||||
|
ID: "testtoken",
|
||||||
|
Name: "test token",
|
||||||
|
HashedToken: "hashed token",
|
||||||
|
}}
|
||||||
|
|
||||||
err := store.SaveAccount(account)
|
err := store.SaveAccount(account)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
|||||||
store := newSqliteStore(t)
|
store := newSqliteStore(t)
|
||||||
|
|
||||||
testUserID := "testuser"
|
testUserID := "testuser"
|
||||||
user := NewAdminUser(testUserID)
|
user := NewAdminUser(testUserID, "account_id")
|
||||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||||
ID: "testtoken",
|
ID: "testtoken",
|
||||||
Name: "test token",
|
Name: "test token",
|
||||||
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
|
|||||||
|
|
||||||
return store.SaveAccount(account)
|
return store.SaveAccount(account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randomIPv4() net.IP {
|
||||||
|
rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
b := make([]byte, 4)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = byte(rand.Intn(256))
|
||||||
|
}
|
||||||
|
return net.IP(b)
|
||||||
|
}
|
||||||
|
|||||||
@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewUser creates a new user
|
// NewUser creates a new user
|
||||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
|
||||||
|
accountID string) *User {
|
||||||
return &User{
|
return &User{
|
||||||
Id: id,
|
Id: ID,
|
||||||
|
AccountID: accountID,
|
||||||
Role: role,
|
Role: role,
|
||||||
IsServiceUser: isServiceUser,
|
IsServiceUser: isServiceUser,
|
||||||
NonDeletable: nonDeletable,
|
NonDeletable: nonDeletable,
|
||||||
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewRegularUser creates a new user with role UserRoleUser
|
// NewRegularUser creates a new user with role UserRoleUser
|
||||||
func NewRegularUser(id string) *User {
|
func NewRegularUser(ID, accountID string) *User {
|
||||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
|
||||||
|
accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||||
func NewAdminUser(id string) *User {
|
func NewAdminUser(ID, accountID string) *User {
|
||||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
|
||||||
|
accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||||
func NewOwnerUser(id string) *User {
|
func NewOwnerUser(ID, accountID string) *User {
|
||||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
|
||||||
|
accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createServiceUser creates a new service user under the given account.
|
// createServiceUser creates a new service user under the given account.
|
||||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
|
||||||
|
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
|
|||||||
}
|
}
|
||||||
|
|
||||||
newUserID := uuid.New().String()
|
newUserID := uuid.New().String()
|
||||||
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
|
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
|
||||||
log.Debugf("New User: %v", newUser)
|
log.Debugf("New User: %v", newUser)
|
||||||
account.Users[newUserID] = newUser
|
account.Users[newUserID] = newUser
|
||||||
|
|
||||||
|
|||||||
@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
|||||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
|
||||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
|
||||||
|
|
||||||
err := store.SaveAccount(account)
|
err := store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
|||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||||
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
|
||||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||||
delete(account.Users, mockUserID)
|
delete(account.Users, mockUserID)
|
||||||
|
|
||||||
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
|||||||
|
|
||||||
func TestUser_IsAdmin(t *testing.T) {
|
func TestUser_IsAdmin(t *testing.T) {
|
||||||
|
|
||||||
user := NewAdminUser(mockUserID)
|
user := NewAdminUser(mockUserID, mockAccountID)
|
||||||
assert.True(t, user.HasAdminPower())
|
assert.True(t, user.HasAdminPower())
|
||||||
|
|
||||||
user = NewRegularUser(mockUserID)
|
user = NewRegularUser(mockUserID, mockAccountID)
|
||||||
assert.False(t, user.HasAdminPower())
|
assert.False(t, user.HasAdminPower())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create other users
|
// create other users
|
||||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
|
||||||
account.Users[adminUserID] = NewAdminUser(adminUserID)
|
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
|
||||||
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
||||||
err = manager.Store.SaveAccount(account)
|
err = manager.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
38
util/grpc/dialer.go
Normal file
38
util/grpc/dialer.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WithCustomDialer() grpc.DialOption {
|
||||||
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||||
|
if currentUser.Uid != "0" {
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
return dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to dial: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package grpc
|
|
||||||
|
|
||||||
import "google.golang.org/grpc"
|
|
||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
|
||||||
return grpc.EmptyDialOption{}
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
return nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
21
util/net/dialer.go
Normal file
21
util/net/dialer.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialer extends the standard net.Dialer with the ability to execute hooks before
|
||||||
|
// and after connections. This can be used to bypass the VPN for connections using this dialer.
|
||||||
|
type Dialer struct {
|
||||||
|
*net.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDialer returns a customized net.Dialer with overridden Control method
|
||||||
|
func NewDialer() *Dialer {
|
||||||
|
dialer := &Dialer{
|
||||||
|
Dialer: &net.Dialer{},
|
||||||
|
}
|
||||||
|
dialer.init()
|
||||||
|
|
||||||
|
return dialer
|
||||||
|
}
|
||||||
@@ -1,19 +1,175 @@
|
|||||||
//go:build !linux || android
|
//go:build !android && !ios
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewDialer() *net.Dialer {
|
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
|
||||||
return &net.Dialer{}
|
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
|
||||||
|
|
||||||
|
var (
|
||||||
|
dialerDialHooksMutex sync.RWMutex
|
||||||
|
dialerDialHooks []DialerDialHookFunc
|
||||||
|
dialerCloseHooksMutex sync.RWMutex
|
||||||
|
dialerCloseHooks []DialerCloseHookFunc
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddDialerHook allows adding a new hook to be executed before dialing.
|
||||||
|
func AddDialerHook(hook DialerDialHookFunc) {
|
||||||
|
dialerDialHooksMutex.Lock()
|
||||||
|
defer dialerDialHooksMutex.Unlock()
|
||||||
|
dialerDialHooks = append(dialerDialHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
|
||||||
|
func AddDialerCloseHook(hook DialerCloseHookFunc) {
|
||||||
|
dialerCloseHooksMutex.Lock()
|
||||||
|
defer dialerCloseHooksMutex.Unlock()
|
||||||
|
dialerCloseHooks = append(dialerCloseHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDialerHook removes all dialer hooks.
|
||||||
|
func RemoveDialerHooks() {
|
||||||
|
dialerDialHooksMutex.Lock()
|
||||||
|
defer dialerDialHooksMutex.Unlock()
|
||||||
|
dialerDialHooks = nil
|
||||||
|
|
||||||
|
dialerCloseHooksMutex.Lock()
|
||||||
|
defer dialerCloseHooksMutex.Unlock()
|
||||||
|
dialerCloseHooks = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||||
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return d.Dialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resolver *net.Resolver
|
||||||
|
if d.Resolver != nil {
|
||||||
|
resolver = d.Resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
connID := GenerateConnID()
|
||||||
|
if dialerDialHooks != nil {
|
||||||
|
if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
|
||||||
|
log.Errorf("Failed to call dialer hooks: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the connection in Conn to handle Close with hooks
|
||||||
|
return &Conn{Conn: conn, ID: connID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
||||||
|
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||||
|
return d.DialContext(context.Background(), network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn wraps a net.Conn to override the Close method
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
ID ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
err := c.Conn.Close()
|
||||||
|
|
||||||
|
dialerCloseHooksMutex.RLock()
|
||||||
|
defer dialerCloseHooksMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, hook := range dialerCloseHooks {
|
||||||
|
if err := hook(c.ID, &c.Conn); err != nil {
|
||||||
|
log.Errorf("Error executing dialer close hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||||
|
host, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("split host and port: %w", err)
|
||||||
|
}
|
||||||
|
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
|
dialerDialHooksMutex.RLock()
|
||||||
|
defer dialerDialHooksMutex.RUnlock()
|
||||||
|
for _, hook := range dialerDialHooks {
|
||||||
|
if err := hook(ctx, connID, ips); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
return net.DialUDP(network, laddr, raddr)
|
return net.DialUDP(network, laddr, raddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return udpConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
return net.DialTCP(network, laddr, raddr)
|
return net.DialTCP(network, laddr, raddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tcpConn, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,59 +2,11 @@
|
|||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import "syscall"
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
// init configures the net.Dialer Control function to set the fwmark on the socket
|
||||||
)
|
func (d *Dialer) init() {
|
||||||
|
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
func NewDialer() *net.Dialer {
|
|
||||||
return &net.Dialer{
|
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
|
||||||
return SetRawSocketMark(c)
|
return SetRawSocketMark(c)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
udpConn, ok := conn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDP connection, got different type")
|
|
||||||
}
|
|
||||||
|
|
||||||
return udpConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpConn, ok := conn.(*net.TCPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected TCP connection, got different type")
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpConn, nil
|
|
||||||
}
|
|
||||||
|
|||||||
15
util/net/dialer_mobile.go
Normal file
15
util/net/dialer_mobile.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build android || ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
|
return net.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||||
|
return net.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
6
util/net/dialer_nonlinux.go
Normal file
6
util/net/dialer_nonlinux.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
func (d *Dialer) init() {
|
||||||
|
}
|
||||||
21
util/net/listener.go
Normal file
21
util/net/listener.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
|
||||||
|
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
|
||||||
|
type ListenerConfig struct {
|
||||||
|
*net.ListenConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewListener creates a new ListenerConfig instance.
|
||||||
|
func NewListener() *ListenerConfig {
|
||||||
|
listener := &ListenerConfig{
|
||||||
|
ListenConfig: &net.ListenConfig{},
|
||||||
|
}
|
||||||
|
listener.init()
|
||||||
|
|
||||||
|
return listener
|
||||||
|
}
|
||||||
@@ -1,13 +1,172 @@
|
|||||||
//go:build !linux || android
|
//go:build !android && !ios
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
import "net"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
func NewListener() *net.ListenConfig {
|
"github.com/pion/transport/v3"
|
||||||
return &net.ListenConfig{}
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
|
||||||
|
type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
|
||||||
|
|
||||||
|
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
|
||||||
|
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
|
||||||
|
|
||||||
|
var (
|
||||||
|
listenerWriteHooksMutex sync.RWMutex
|
||||||
|
listenerWriteHooks []ListenerWriteHookFunc
|
||||||
|
listenerCloseHooksMutex sync.RWMutex
|
||||||
|
listenerCloseHooks []ListenerCloseHookFunc
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
|
||||||
|
func AddListenerWriteHook(hook ListenerWriteHookFunc) {
|
||||||
|
listenerWriteHooksMutex.Lock()
|
||||||
|
defer listenerWriteHooksMutex.Unlock()
|
||||||
|
listenerWriteHooks = append(listenerWriteHooks, hook)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) {
|
// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
|
||||||
return net.ListenUDP(network, locAddr)
|
func AddListenerCloseHook(hook ListenerCloseHookFunc) {
|
||||||
|
listenerCloseHooksMutex.Lock()
|
||||||
|
defer listenerCloseHooksMutex.Unlock()
|
||||||
|
listenerCloseHooks = append(listenerCloseHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveListenerHooks removes all dialer hooks.
|
||||||
|
func RemoveListenerHooks() {
|
||||||
|
listenerWriteHooksMutex.Lock()
|
||||||
|
defer listenerWriteHooksMutex.Unlock()
|
||||||
|
listenerWriteHooks = nil
|
||||||
|
|
||||||
|
listenerCloseHooksMutex.Lock()
|
||||||
|
defer listenerCloseHooksMutex.Unlock()
|
||||||
|
listenerCloseHooks = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenPacket listens on the network address and returns a PacketConn
|
||||||
|
// which includes support for write hooks.
|
||||||
|
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen packet: %w", err)
|
||||||
|
}
|
||||||
|
connID := GenerateConnID()
|
||||||
|
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
|
||||||
|
type PacketConn struct {
|
||||||
|
net.PacketConn
|
||||||
|
ID ConnectionID
|
||||||
|
seenAddrs *sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||||
|
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
||||||
|
return c.PacketConn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *PacketConn) Close() error {
|
||||||
|
c.seenAddrs = &sync.Map{}
|
||||||
|
return closeConn(c.ID, c.PacketConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
|
||||||
|
type UDPConn struct {
|
||||||
|
*net.UDPConn
|
||||||
|
ID ConnectionID
|
||||||
|
seenAddrs *sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||||
|
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
||||||
|
return c.UDPConn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *UDPConn) Close() error {
|
||||||
|
c.seenAddrs = &sync.Map{}
|
||||||
|
return closeConn(c.ID, c.UDPConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
||||||
|
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
||||||
|
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
||||||
|
ipStr, _, splitErr := net.SplitHostPort(addr.String())
|
||||||
|
if splitErr != nil {
|
||||||
|
log.Errorf("Error splitting IP address and port: %v", splitErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := net.ResolveIPAddr("ip", ipStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error resolving IP address: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("Listener resolved IP for %s: %s", addr, ip)
|
||||||
|
|
||||||
|
func() {
|
||||||
|
listenerWriteHooksMutex.RLock()
|
||||||
|
defer listenerWriteHooksMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, hook := range listenerWriteHooks {
|
||||||
|
if err := hook(id, ip, b); err != nil {
|
||||||
|
log.Errorf("Error executing listener write hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeConn(id ConnectionID, conn net.PacketConn) error {
|
||||||
|
err := conn.Close()
|
||||||
|
|
||||||
|
listenerCloseHooksMutex.RLock()
|
||||||
|
defer listenerCloseHooksMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, hook := range listenerCloseHooks {
|
||||||
|
if err := hook(id, conn); err != nil {
|
||||||
|
log.Errorf("Error executing listener close hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||||
|
// which includes support for write and close hooks.
|
||||||
|
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.ListenUDP(network, laddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packetConn := conn.(*PacketConn)
|
||||||
|
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := packetConn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,28 +3,12 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener() *net.ListenConfig {
|
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||||
return &net.ListenConfig{
|
func (l *ListenerConfig) init() {
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
return SetRawSocketMark(c)
|
return SetRawSocketMark(c)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
||||||
pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err)
|
|
||||||
}
|
|
||||||
udpConn, ok := pc.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("packetConn is not a *net.UDPConn")
|
|
||||||
}
|
|
||||||
return udpConn, nil
|
|
||||||
}
|
|
||||||
|
|||||||
11
util/net/listener_mobile.go
Normal file
11
util/net/listener_mobile.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build android || ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
|
return net.ListenUDP(network, laddr)
|
||||||
|
}
|
||||||
6
util/net/listener_nonlinux.go
Normal file
6
util/net/listener_nonlinux.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
func (l *ListenerConfig) init() {
|
||||||
|
}
|
||||||
@@ -1,6 +1,27 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||||
NetbirdFwmark = 0x1BD00
|
NetbirdFwmark = 0x1BD00
|
||||||
|
|
||||||
|
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ConnectionID provides a globally unique identifier for network connections.
|
||||||
|
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
||||||
|
type ConnectionID string
|
||||||
|
|
||||||
|
// GenerateConnID generates a unique identifier for each connection.
|
||||||
|
func GenerateConnID() ConnectionID {
|
||||||
|
return ConnectionID(uuid.NewString())
|
||||||
|
}
|
||||||
|
|
||||||
|
func CustomRoutingDisabled() bool {
|
||||||
|
return os.Getenv(envDisableCustomRouting) == "true"
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
|||||||
var setErr error
|
var setErr error
|
||||||
|
|
||||||
err := conn.Control(func(fd uintptr) {
|
err := conn.Control(func(fd uintptr) {
|
||||||
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
setErr = SetSocketOpt(int(fd))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("control: %w", err)
|
return fmt.Errorf("control: %w", err)
|
||||||
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetSocketOpt(fd int) error {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user