Compare commits

...

33 Commits

Author SHA1 Message Date
Zoltán Papp
46b564db93 Add naive speed test 2024-04-25 15:03:55 +02:00
Zoltan Papp
c6a282f1e3 Fix on mac 2024-04-24 19:25:08 +02:00
Zoltán Papp
adb787954e Discover NAT side ip with STUN 2024-04-24 18:59:55 +02:00
Zoltán Papp
318d379658 Update wireguard-go 2024-04-22 19:13:07 +02:00
Zoltán Papp
b68a02acee Close turn connection
Without it the WG can not exit from the read loop
2024-04-18 15:59:50 +02:00
Zoltán Papp
b5c4802bb9 Apply new receiver functions 2024-04-16 16:01:25 +02:00
Zoltán Papp
28a9a2ef87 Move receiverCreator to new file 2024-04-15 14:12:24 +02:00
Zoltán Papp
b355c34b63 Configure wg with proper address 2024-04-12 18:29:35 +02:00
Zoltán Papp
d67f766b2e Initial code 2024-04-12 17:38:31 +02:00
Carlos Hernandez
76702c8a09 Add safe read/write to route map (#1760) 2024-04-11 22:12:23 +02:00
Viktor Liu
061f673a4f Don't use the custom dialer as non-root (#1823) 2024-04-11 15:29:03 +02:00
Zoltan Papp
9505805313 Rename variable (#1829) 2024-04-11 14:08:03 +02:00
Maycon Santos
704c67dec8 Allow owners that did not create the account to delete it (#1825)
Sometimes the Owner role will be passed to new users, and they need to be able to delete the account
2024-04-11 10:02:51 +02:00
pascal-fischer
3ed2f08f3c Add latency based routing (#1732)
Now that we have the latency between peers available we can use this data to consider when choosing the best route. This way the route with the routing peer with the lower latency will be preferred over others with the same target network.
2024-04-09 21:20:02 +02:00
Maycon Santos
4c83408f27 Add log-level to the management's docker service command (#1820) 2024-04-09 21:00:43 +02:00
Viktor Liu
90bd39c740 Log panics (#1818) 2024-04-09 20:27:27 +02:00
Maycon Santos
dd0cf41147 Auto restart Windows agent daemon service (#1819)
This enables auto restart of the windows agent daemon service on event of failure
2024-04-09 20:10:59 +02:00
pascal-fischer
22b2caffc6 Remove dns based cloud detection (#1812)
* remove dns based cloud checks

* remove dns based cloud checks
2024-04-09 19:01:31 +02:00
Viktor Liu
c1f66d1354 Retry macOS route command (#1817) 2024-04-09 15:27:19 +02:00
Viktor Liu
ac0fe6025b Fix routing issues with MacOS (#1815)
* Handle zones properly

* Use host routes for single IPs 

* Add GOOS and GOARCH to startup log

* Log powershell command
2024-04-09 13:25:14 +02:00
verytrap
c28657710a Fix function names in comments (#1816)
Signed-off-by: verytrap <wangqiuyue@outlook.com>
2024-04-09 13:18:38 +02:00
Maycon Santos
3875c29f6b Revert "Rollback new routing functionality (#1805)" (#1813)
This reverts commit 9f32ccd453.
2024-04-08 18:56:52 +02:00
Viktor Liu
9f32ccd453 Rollback new routing functionality (#1805) 2024-04-05 20:38:49 +02:00
trax
1d1d057e7d Change the dashboard image pull from wiretrustee to netbirdio (#1804) 2024-04-05 13:51:28 +02:00
Viktor Liu
3461b1bb90 Expect correct conn type (#1801) 2024-04-05 00:10:32 +02:00
Viktor Liu
3d2a2377c6 Don't return errors on disallowed routes (#1792) 2024-04-03 19:06:04 +02:00
Viktor Liu
25f5f26527 Timeout rule removing loop and catch IPv6 unsupported error in loop (#1791) 2024-04-03 18:57:50 +02:00
Viktor Liu
bb0d5c5baf Linux legacy routing (#1774)
* Add Linux legacy routing if ip rule functionality is not available

* Ignore exclusion route errors if host has no route

* Exclude iOS from route manager

* Also retrieve IPv6 routes

* Ignore loopback addresses not being in the main table

* Ignore "not supported" errors on cleanup

* Fix regression in ListenUDP not using fwmarks
2024-04-03 18:04:22 +02:00
Viktor Liu
7938295190 Feature/exit nodes - Windows and macOS support (#1726) 2024-04-03 11:11:46 +02:00
rqi14
9af532fe71 Get scope from endpoint url instead of hardcoding (#1770) 2024-04-02 13:43:57 +02:00
Vilian Gerdzhikov
23a1473797 Fix grammar in readme (#1778) 2024-04-02 10:08:58 +02:00
Misha Bragin
9c2dc05df1 Eval/higher timeouts (#1776) 2024-03-31 19:39:52 +02:00
Misha Bragin
40d56e5d29 Update network security image (#1765) 2024-03-28 18:43:32 +01:00
79 changed files with 3643 additions and 1242 deletions

View File

@@ -32,6 +32,9 @@ jobs:
restore-keys: |
macos-go-
- name: Install libpcap
run: brew install libpcap
- name: Install modules
run: go mod tidy

View File

@@ -46,7 +46,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -40,11 +40,12 @@
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open-Source Network Security in a Single Platform
![download (2)](https://github.com/netbirdio/netbird/assets/700848/16210ac2-7265-44c1-8d4e-8fae85534dac)
![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f)
### Key features
@@ -76,7 +77,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
@@ -93,9 +94,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
@@ -119,7 +120,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
### Legal
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -64,6 +64,10 @@ var installCmd = &cobra.Command{
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), svcConfig)

View File

@@ -143,7 +143,7 @@ func (m *Manager) AllowNetbird() error {
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
log.Debugf("allow netbird rule already exists: %#v", rule)
return nil
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"runtime"
"runtime/debug"
"strings"
"time"
@@ -93,7 +95,13 @@ func runClient(
relayProbe *Probe,
wgProbe *Probe,
) error {
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.

View File

@@ -93,6 +93,10 @@ type Engine struct {
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
beforePeerHook peer.BeforeAddPeerHookFunc
afterPeerHook peer.AfterRemovePeerHookFunc
// rpManager is a Rosenpass manager
rpManager *rosenpass.Manager
@@ -134,6 +138,7 @@ type Engine struct {
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
turnRelay *relay.PermanentTurn
}
// Peer is an instance of the Connection Peer
@@ -195,7 +200,7 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
wgProxyFactory: &wgproxy.Factory{},
mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
@@ -260,9 +265,14 @@ func (e *Engine) Start() error {
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
if err := e.routeManager.Init(); err != nil {
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
}
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
err = e.wgInterfaceCreate()
@@ -443,10 +453,19 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
t = sProto.Body_OFFER
}
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
}, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
msg, err := signal.MarshalCredential(
myKey,
offerAnswer.WgListenPort,
remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
t,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelayedAddr.String(),
offerAnswer.RemoteAddr.String(),
)
if err != nil {
return err
}
@@ -474,6 +493,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
turnRelay := relay.NewPermanentTurn(e.STUNs[0], e.TURNs[0])
err = turnRelay.Open()
if err != nil {
return fmt.Errorf("faile to open turn relay: %w", err)
}
e.turnRelay = turnRelay
e.wgInterface.SetRelayConn(e.turnRelay.RelayConn())
// todo update signal
}
@@ -612,6 +639,7 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
var newTURNs []*stun.URI
log.Debugf("got TURNs update from Management Service, updating")
for _, turn := range turns {
log.Debugf("-----updated Turn %v, %s, %s", turn.HostConfig.Uri, turn.User, turn.Password)
url, err := stun.ParseURI(turn.HostConfig.Uri)
if err != nil {
return err
@@ -621,7 +649,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
newTURNs = append(newTURNs, url)
}
e.TURNs = newTURNs
return nil
}
@@ -785,6 +812,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
}
e.statusRecorder.ReplaceOfflinePeers(replacement)
@@ -808,10 +836,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
if _, ok := e.peerConns[peerKey]; !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil {
return err
return fmt.Errorf("create peer connection: %w", err)
}
e.peerConns[peerKey] = conn
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
@@ -919,7 +952,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
RosenpassAddr: e.getRosenpassAddr(),
}
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover, e.turnRelay)
if err != nil {
return nil, err
}
@@ -985,6 +1018,17 @@ func (e *Engine) receiveSignalEvents() {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
if err != nil {
return err
}
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -994,6 +1038,8 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelayedAddr: relayedAddr,
RemoteAddr: remoteAddr,
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1009,6 +1055,17 @@ func (e *Engine) receiveSignalEvents() {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
if err != nil {
return err
}
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -1018,6 +1075,8 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelayedAddr: relayedAddr,
RemoteAddr: remoteAddr,
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
@@ -1028,7 +1087,6 @@ func (e *Engine) receiveSignalEvents() {
conn.OnRemoteCandidate(candidate)
case sProto.Body_MODE:
}
return nil
})
if err != nil {
@@ -1100,11 +1158,17 @@ func (e *Engine) close() {
log.Errorf("failed closing ebpf proxy: %s", err)
}
e.turnRelay.Close()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()
}
if e.routeManager != nil {
e.routeManager.Stop()
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
@@ -1119,10 +1183,6 @@ func (e *Engine) close() {
}
}
if e.routeManager != nil {
e.routeManager.Stop()
}
if e.firewall != nil {
err := e.firewall.Reset()
if err != nil {

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"runtime"
"strings"
"sync"
"time"
@@ -14,18 +13,22 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
defaultWgKeepAlive = 25 * time.Second
)
@@ -90,6 +93,10 @@ type OfferAnswer struct {
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
// This value is the local Rosenpass server address when sending the message
RosenpassAddr string
// Turn Relay
RelayedAddr net.Addr
RemoteAddr net.Addr
}
// IceCredentials ICE protocol credentials struct
@@ -98,6 +105,9 @@ type IceCredentials struct {
Pwd string
}
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
type Conn struct {
config ConnConfig
mu sync.Mutex
@@ -135,7 +145,11 @@ type Conn struct {
sentExtraSrflx bool
remoteEndpoint *net.UDPAddr
remoteConn *ice.Conn
connID nbnet.ConnectionID
beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc
turnRelay *relay.PermanentTurn
}
// meta holds meta information about a connection
@@ -166,7 +180,7 @@ func (conn *Conn) UpdateStunTurn(turnStun []*stun.URI) {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, turnRelay *relay.PermanentTurn) (*Conn, error) {
return &Conn{
config: config,
mu: sync.Mutex{},
@@ -179,6 +193,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
wgProxyFactory: wgProxyFactory,
adapter: adapter,
iFaceDiscover: iFaceDiscover,
turnRelay: turnRelay,
}, nil
}
@@ -196,20 +211,22 @@ func (conn *Conn) reCreateAgent() error {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: conn.candidateTypes(),
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
Net: transportNet,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{},
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
Net: transportNet,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
}
if conn.config.DisableIPv6Discovery {
@@ -217,7 +234,6 @@ func (conn *Conn) reCreateAgent() error {
}
conn.agent, err = ice.NewAgent(agentConfig)
if err != nil {
return err
}
@@ -251,17 +267,6 @@ func (conn *Conn) reCreateAgent() error {
return nil
}
func (conn *Conn) candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
// Open opens connection to the remote peer starting ICE candidate gathering process.
// Blocks until connection has been closed or connection timeout.
// ConnStatus will be set accordingly
@@ -273,6 +278,7 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -332,48 +338,60 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
}
err = conn.agent.GatherCandidates()
if err != nil {
return err
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
isControlling := conn.config.LocalKey > conn.config.Key
var remoteConn *ice.Conn
isControlling := conn.config.LocalKey < conn.config.Key
if isControlling {
remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
log.Debugf("---- use this peer's tunr connection")
err = conn.turnRelay.PunchHole(remoteOfferAnswer.RemoteAddr)
if err != nil {
log.Errorf("failed to punch hole: %v", err)
}
addr, ok := remoteOfferAnswer.RemoteAddr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("failed to cast addr to udp addr")
}
addr.Port = remoteOfferAnswer.WgListenPort
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
// todo close
return err
}
} else {
remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
if err != nil {
return err
}
log.Debugf("---- use remote peer tunr connection")
addr, ok := remoteOfferAnswer.RelayedAddr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("failed to cast addr to udp addr")
}
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
// todo close
return err
}
// dynamically set remote WireGuard port is other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
// the ice connection has been established successfully so we are ready to start the proxy
/*
remoteAddr, err := conn.configureConnection(remoteOfferAnswer.RelayedAddr, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
if err != nil {
return err
}
*/
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, addr.String())
}
conn.remoteConn = remoteConn
// the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
if err != nil {
return err
}
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select {
case <-conn.closeCh:
@@ -389,34 +407,33 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
conn.mu.Lock()
defer conn.mu.Unlock()
pair, err := conn.agent.GetSelectedCandidatePair()
if err != nil {
return nil, err
}
var endpoint net.Addr
if isRelayCandidate(pair.Local) {
log.Debugf("setup relay connection")
conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil {
return nil, err
}
} else {
// To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
endpoint = remoteConn.RemoteAddr()
}
endpoint = remoteConn.RemoteAddr()
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
conn.connID = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
log.Errorf("Before add peer hook failed: %v", err)
}
}
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
@@ -425,30 +442,33 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
conn.status = StatusConnected
rosenpassEnabled := false
if remoteRosenpassPubKey != nil {
rosenpassEnabled = true
}
/*
rosenpassEnabled := false
if remoteRosenpassPubKey != nil {
rosenpassEnabled = true
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err)
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err)
}
*/
_, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
if err != nil {
@@ -506,6 +526,15 @@ func (conn *Conn) cleanup() error {
// todo: is it problem if we try to remove a peer what is never existed?
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if conn.connID != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connID); err != nil {
log.Errorf("After remove peer hook failed: %v", err)
}
}
}
conn.connID = ""
if conn.notifyDisconnected != nil {
conn.notifyDisconnected()
conn.notifyDisconnected = nil
@@ -521,6 +550,7 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -640,6 +670,8 @@ func (conn *Conn) sendAnswer() error {
Version: version.NetbirdVersion(),
RosenpassPubKey: conn.config.RosenpassPubKey,
RosenpassAddr: conn.config.RosenpassAddr,
RelayedAddr: conn.turnRelay.RelayedAddress(),
RemoteAddr: conn.turnRelay.SrvRefAddr(),
})
if err != nil {
return err
@@ -663,6 +695,8 @@ func (conn *Conn) sendOffer() error {
Version: version.NetbirdVersion(),
RosenpassPubKey: conn.config.RosenpassPubKey,
RosenpassAddr: conn.config.RosenpassAddr,
RelayedAddr: conn.turnRelay.RelayedAddress(),
RemoteAddr: conn.turnRelay.SrvRefAddr(),
})
if err != nil {
return err
@@ -702,6 +736,10 @@ func (conn *Conn) Status() ConnStatus {
return conn.status
}
func (conn *Conn) OnRemoteRelayRequest(relayedAddr string, remoteIP string) {
}
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {

View File

@@ -10,9 +10,10 @@ import (
)
const (
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
)
func iceKeepAlive() time.Duration {
@@ -21,7 +22,7 @@ func iceKeepAlive() time.Duration {
return iceKeepAliveDefault
}
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
@@ -37,7 +38,7 @@ func iceDisconnectedTimeout() time.Duration {
return iceDisconnectedTimeoutDefault
}
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
@@ -47,6 +48,22 @@ func iceDisconnectedTimeout() time.Duration {
return time.Duration(disconnectedTimeoutSec) * time.Second
}
func iceRelayAcceptanceMinWait() time.Duration {
iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec)
if iceRelayAcceptanceMinWaitEnv == "" {
return iceRelayAcceptanceMinWaitDefault
}
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
return iceRelayAcceptanceMinWaitDefault
}
return time.Duration(disconnectedTimeoutSec) * time.Second
}
func hasICEForceRelayConn() bool {
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
return strings.ToLower(disconnectedTimeoutEnv) == "true"

View File

@@ -14,6 +14,7 @@ import (
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
IP string
PubKey string
FQDN string
@@ -30,7 +31,38 @@ type State struct {
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
Routes map[string]struct{}
routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
}
// LocalPeerState contains the latest state of the local peer
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
return nil
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.Routes != nil {
peerState.Routes = receivedState.Routes
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState, peerState)
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true
}
return false
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"errors"
"testing"
"sync"
"github.com/stretchr/testify/assert"
)
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -0,0 +1,143 @@
package relay
import (
"fmt"
"io"
"net"
"os"
"time"
log "github.com/sirupsen/logrus"
)
const (
bufferSize = 800
testFile = "/tmp/1MB"
)
type Speed struct {
}
func NewSpeed() *Speed {
return &Speed{}
}
func (s *Speed) ReceiveFileFromAddr(remoteAddr net.Addr) error {
pc, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
log.Errorf("failed to lisen: %s", err.Error())
return err
}
defer pc.Close()
log.Debugf("--- sending initial message to: %s", remoteAddr.String())
_, err = pc.WriteTo([]byte("hey, I am the receiver"), remoteAddr)
if err != nil {
log.Errorf("failed to send initial msg: %s", err.Error())
return err
}
return s.receiveFile(pc)
}
func (s *Speed) ReceiveFileFromPC(pc net.PacketConn) error {
return s.receiveFile(pc)
}
func (s *Speed) receiveFile(pc net.PacketConn) error {
log.Debugf("--- start to receive file...")
file, err := os.OpenFile(fmt.Sprintf("%s.cp", testFile), os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Errorf("failed to open file: %s", err.Error())
return err
}
_ = file.Truncate(0)
defer file.Close()
buffer := make([]byte, bufferSize)
for {
n, addr, err := pc.ReadFrom(buffer)
if err != nil {
log.Errorf("failed to read from connection: %s", err.Error())
return err
}
n, err = file.Write(buffer[:n])
if err != nil {
log.Errorf("failed to write to file: %s", err.Error())
return err
}
_, err = pc.WriteTo([]byte("ack"), addr)
if err != nil {
log.Errorf("failed to send ack: %s", err.Error())
}
log.Debugf("received %d bytes from %s", n, addr)
}
}
func (s *Speed) SendFileToPC(relayConn net.PacketConn) error {
buf := make([]byte, bufferSize)
log.Debugf("--- wait for initial message")
n, rAddr, err := relayConn.ReadFrom(buf)
if err != nil {
log.Errorf("failed to read from connection: %s", err.Error())
return err
}
log.Errorf("received initial msg %d bytes (%s), addr %s", n, string(buf[:n]), rAddr.String())
return s.sendFile(relayConn, rAddr)
}
func (s *Speed) SendFileToAddr(addr net.Addr) error {
pc, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
log.Errorf("failed to lisen: %s", err.Error())
return err
}
defer pc.Close()
return s.sendFile(pc, addr)
}
func (s *Speed) sendFile(conn net.PacketConn, rAddr net.Addr) error {
log.Debugf("--- start to send file...")
file, err := os.Open(testFile)
if err != nil {
// Handle error
return nil
}
defer file.Close()
buf := make([]byte, bufferSize)
start := time.Now()
sent := 0
for {
n, err := file.Read(buf)
if err != nil && err != io.EOF {
return err
}
if n == 0 {
break
}
n, err = conn.WriteTo(buf[:n], rAddr)
if err != nil {
log.Errorf("failed to write to connection: %s", err.Error())
return err
}
sent += n
log.Debugf("sent %d bytes, (%d) to %s", n, sent, rAddr.String())
// wait for ack
_, _, err = conn.ReadFrom(make([]byte, bufferSize))
if err != nil {
log.Errorf("failed to read from connection: %s", err.Error())
return err
}
}
elapsed := time.Since(start)
log.Infof("sent %d bytes, troughtput: %f MB/s", sent, float64(sent)/1024/1024/elapsed.Seconds())
return nil
}

View File

@@ -0,0 +1,198 @@
package relay
import (
"fmt"
"math"
"net"
"sync"
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
)
type PermanentTurn struct {
stunURI *stun.URI
turnURI *stun.URI
stunConn net.PacketConn
turnClient *turn.Client
turnClientListenLock sync.Mutex
relayConn net.PacketConn // represents the remote socket.
srvReflexiveAddress *net.UDPAddr
}
func NewPermanentTurn(stunURL, turnURL *stun.URI) *PermanentTurn {
return &PermanentTurn{
stunURI: stunURL,
turnURI: turnURL,
}
}
func (r *PermanentTurn) Open() error {
stunConn, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
return err
}
r.stunConn = stunConn
cfg := &turn.ClientConfig{
STUNServerAddr: toURL(r.stunURI),
TURNServerAddr: toURL(r.turnURI),
Conn: stunConn,
Username: r.turnURI.Username,
Password: r.turnURI.Password,
LoggerFactory: logging.NewDefaultLoggerFactory(),
}
client, err := turn.NewClient(cfg)
if err != nil {
log.Errorf("failed to create turn client: %v", err)
return err
}
r.turnClient = client
err = r.turnClient.Listen()
if err != nil {
log.Errorf("failed to listen turn client: %v", err)
return err
}
relayConn, err := client.Allocate()
if err != nil {
log.Errorf("failed to allocate relay connection: %v", err)
return err
}
r.relayConn = relayConn
srvReflexiveAddress, err := r.discoverPublicIPByStun()
if err != nil {
log.Errorf("failed to discover public IP: %v", err)
return err
}
r.srvReflexiveAddress = srvReflexiveAddress
return nil
}
func (r *PermanentTurn) RelayedAddress() net.Addr {
return r.relayConn.LocalAddr()
}
func (r *PermanentTurn) SrvRefAddr() net.Addr {
return r.srvReflexiveAddress
}
func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error {
/*
err := r.turnClient.CreatePermission(mappedAddr)
if err != nil {
log.Errorf("---- failed to create permission: %v", err)
return err
}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.Fingerprint,
)
if err != nil {
log.Errorf("--- failed to build stun message: %v", err)
return nil
}
_, err = r.relayConn.WriteTo(msg.Raw, mappedAddr)
if err != nil {
log.Errorf("failed to write to relay conn: %v", err)
return err
}
*/
_, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr)
return err
}
func (r *PermanentTurn) RelayConn() net.PacketConn {
return r.relayConn
}
func (r *PermanentTurn) Close() {
r.turnClient.Close()
err := r.relayConn.Close()
if err != nil {
log.Errorf("failed to close relayConn: %s", err.Error())
}
err = r.stunConn.Close()
if err != nil {
log.Errorf("failed to close stunConn: %s", err.Error())
}
}
func (r *PermanentTurn) discoverPublicIP() (*net.UDPAddr, error) {
addr, err := r.turnClient.SendBindingRequest()
if err != nil {
log.Errorf("failed to send binding request: %v", err)
return nil, err
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return nil, fmt.Errorf("failed to cast addr to udp addr")
}
return udpAddr, nil
}
func (r *PermanentTurn) discoverPublicIPByStun() (*net.UDPAddr, error) {
c, err := stun.DialURI(r.stunURI, &stun.DialConfig{})
if err != nil {
panic(err)
}
message := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
var addr *net.UDPAddr
err = c.Do(message, func(res stun.Event) {
if res.Error != nil {
panic(res.Error)
}
var xorAddr stun.XORMappedAddress
if err := xorAddr.GetFrom(res.Message); err != nil {
log.Errorf("failed to get xor address: %v", err)
return
}
addr = &net.UDPAddr{
IP: xorAddr.IP,
Port: xorAddr.Port,
}
})
if err != nil {
return nil, err
}
return addr, nil
}
func (r *PermanentTurn) listen() {
if !r.turnClientListenLock.TryLock() {
return
}
go func() {
defer r.turnClientListenLock.Unlock()
buf := make([]byte, math.MaxUint16)
for {
n, from, err := r.stunConn.ReadFrom(buf)
if err != nil {
log.Errorf("Failed to read from stun conn. Exiting loop %v", err)
break
}
_, err = r.turnClient.HandleInbound(buf[:n], from)
if err != nil {
log.Errorf("Failed to handle inbound turn message: %s. Exiting loop", err)
break
}
}
}()
}
func toURL(uri *stun.URI) string {
return fmt.Sprintf("%s:%d", uri.Host, uri.Port)
}

View File

@@ -0,0 +1,137 @@
package relay
import (
"os"
"sync"
"testing"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
const (
userName = "1714092678"
password = "8PEprGKo+UARpYpQOulNz3H24dI="
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestMyTurnUpload(t *testing.T) {
turnURI, err := stun.ParseURI("turn:api.stage.netbird.io:3478?transport=udp")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnURI.Username = userName
turnURI.Password = password
stunURI, err := stun.ParseURI("stun:api.stage.netbird.io:3478")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnRelayA := NewPermanentTurn(stunURI, turnURI)
err = turnRelayA.Open()
if err != nil {
t.Fatalf("failed to open turn relay: %v", err)
}
defer turnRelayA.Close()
turnRelayB := NewPermanentTurn(stunURI, turnURI)
peerBAddr, err := turnRelayB.discoverPublicIPByStun()
if err != nil {
t.Fatalf("failed to discover public ip: %v", err)
}
err = turnRelayA.PunchHole(peerBAddr)
if err != nil {
t.Fatalf("failed to punch hole: %v", err)
}
// at this point, the relayed side should be established
wg := sync.WaitGroup{}
wg.Add(2)
speedB := NewSpeed()
go func() {
err := speedB.ReceiveFileFromAddr(turnRelayA.relayConn.LocalAddr())
if err != nil {
log.Errorf("failed to receive file: %v", err)
}
wg.Done()
}()
speedA := NewSpeed()
go func() {
err := speedA.SendFileToPC(turnRelayA.relayConn)
if err != nil {
log.Errorf("failed to send file: %v", err)
}
log.Debugf("file sent")
wg.Done()
}()
wg.Wait()
}
func TestMyTurnDownload(t *testing.T) {
turnURI, err := stun.ParseURI("turn:api.stage.netbird.io:3478?transport=udp")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnURI.Username = "1714016034"
turnURI.Password = "oDpL6tDu0d+xcO3rQnHoEvbcS/Q="
stunURI, err := stun.ParseURI("stun:api.stage.netbird.io:3478")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnRelayA := NewPermanentTurn(stunURI, turnURI)
err = turnRelayA.Open()
if err != nil {
t.Fatalf("failed to open turn relay: %v", err)
}
defer turnRelayA.Close()
turnRelayB := NewPermanentTurn(stunURI, turnURI)
peerBAddr, err := turnRelayB.discoverPublicIPByStun()
if err != nil {
t.Fatalf("failed to discover public ip: %v", err)
}
err = turnRelayA.PunchHole(peerBAddr)
if err != nil {
t.Fatalf("failed to punch hole: %v", err)
}
// at this point, the relayed side should be established
wg := sync.WaitGroup{}
wg.Add(2)
speedB := NewSpeed()
go func() {
err := speedB.SendFileToAddr(turnRelayA.relayConn.LocalAddr())
if err != nil {
log.Errorf("failed to receive file: %v", err)
}
wg.Done()
}()
speedA := NewSpeed()
go func() {
err := speedA.ReceiveFileFromPC(turnRelayA.relayConn)
if err != nil {
log.Errorf("failed to send file: %v", err)
}
log.Debugf("file sent")
wg.Done()
}()
wg.Wait()
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
@@ -18,6 +19,7 @@ type routerPeerStatus struct {
connected bool
relayed bool
direct bool
latency time.Duration
}
type routesUpdate struct {
@@ -68,6 +70,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed,
direct: peerStatus.Direct,
latency: peerStatus.Latency,
}
}
return routePeerStatuses
@@ -83,11 +86,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
// * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
chosen := ""
chosenScore := 0
chosenScore := float64(0)
currScore := float64(0)
currID := ""
if c.chosenRoute != nil {
@@ -95,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
}
for _, r := range c.routes {
tempScore := 0
tempScore := float64(0)
peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected {
continue
@@ -103,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric
tempScore = metricDiff * 10
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
}
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
tempScore++
}
@@ -114,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
tempScore++
}
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
chosen = r.ID
chosenScore = tempScore
}
@@ -123,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
chosen = r.ID
chosenScore = tempScore
}
if r.ID == currID {
currScore = tempScore
}
}
if chosen == "" {
switch {
case chosen == "":
var peers []string
for _, r := range c.routes {
peers = append(peers, r.Peer)
}
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
} else if chosen != currID {
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
case chosen != currID:
if currScore != 0 && currScore < chosenScore+0.1 {
return currID
} else {
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
}
}
return chosen
@@ -174,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
return fmt.Errorf("get peer state: %v", err)
}
delete(state.Routes, c.network.String())
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
@@ -193,7 +215,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil {
if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
}
@@ -234,7 +256,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// otherwise add the route to the system
if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
}
@@ -246,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}

View File

@@ -3,6 +3,7 @@ package routemanager
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/route"
)
@@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
name string
statuses map[string]routerPeerStatus
expectedRouteID string
currentRoute *route.Route
currentRoute string
existingRoutes map[string]*route.Route
}{
{
@@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "",
},
{
@@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
@@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: nil,
currentRoute: "",
expectedRouteID: "route1",
},
{
name: "multiple connected peers with different latencies",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 300 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "should ignore routes with latency 0",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 12 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current chosen route doesn't exist anymore",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "routeDoesntExistAnymore",
expectedRouteID: "route2",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
ID: "routeDoesntExistAnymore",
}
if tc.currentRoute != "" {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
// create new clientNetwork
client := &clientNetwork{
network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes,
chosenRoute: tc.currentRoute,
chosenRoute: currentRoute,
}
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@@ -3,7 +3,9 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"runtime"
"sync"
@@ -24,7 +26,7 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface
type Manager interface {
Init() error
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -65,16 +67,21 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
}
// Init sets up the routing
func (m *DefaultManager) Init() error {
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
if err := setupRouting(); err != nil {
return fmt.Errorf("setup routing: %w", err)
mgmtAddress := m.statusRecorder.GetManagementState().URL
signalAddress := m.statusRecorder.GetSignalState().URL
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}
log.Info("Routing setup complete")
return nil
return beforePeerHook, afterPeerHook, nil
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
@@ -203,16 +210,36 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
}
func isPrefixSupported(prefix netip.Prefix) bool {
if runtime.GOOS == "linux" {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management
if prefix.Bits() < minRangeBits {
if prefix.Bits() <= minRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), prefix)
return false
}
return true
}
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
func resolveURLsToIPs(urls []string) []net.IP {
var ips []net.IP
for _, rawurl := range urls {
u, err := url.Parse(rawurl)
if err != nil {
log.Errorf("Failed to parse url %s: %v", rawurl, err)
continue
}
ipAddrs, err := net.LookupIP(u.Hostname())
if err != nil {
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
continue
}
ips = append(ips, ipAddrs...)
}
return ips
}

View File

@@ -28,14 +28,14 @@ const remotePeerKey2 = "remote1"
func TestManagerUpdateRoutes(t *testing.T) {
testCases := []struct {
name string
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
removeSrvRouter bool
serverRoutesExpected int
clientNetworkWatchersExpected int
clientNetworkWatchersExpectedLinux int
name string
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
removeSrvRouter bool
serverRoutesExpected int
clientNetworkWatchersExpected int
clientNetworkWatchersExpectedAllowed int
}{
{
name: "Should create 2 client networks",
@@ -201,9 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 0,
clientNetworkWatchersExpectedLinux: 1,
inputSerial: 1,
clientNetworkWatchersExpected: 0,
clientNetworkWatchersExpectedAllowed: 1,
},
{
name: "Remove 1 Client Route",
@@ -417,7 +417,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
err = routeManager.Init()
_, _, err = routeManager.Init()
require.NoError(t, err, "should init route manager")
defer routeManager.Stop()
@@ -434,8 +436,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected
if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedLinux
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
}
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")

View File

@@ -6,6 +6,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
@@ -16,8 +17,8 @@ type MockManager struct {
StopFunc func()
}
func (m *MockManager) Init() error {
return nil
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface

View File

@@ -0,0 +1,126 @@
//go:build !android && !ios
package routemanager
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
type ref struct {
count int
nexthop netip.Addr
intf string
}
type RouteManager struct {
// refCountMap keeps track of the reference ref for prefixes
refCountMap map[netip.Prefix]ref
// prefixMap keeps track of the prefixes associated with a connection ID for removal
prefixMap map[nbnet.ConnectionID][]netip.Prefix
addRoute AddRouteFunc
removeRoute RemoveRouteFunc
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap
return &RouteManager{
refCountMap: map[netip.Prefix]ref{},
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
addRoute: addRoute,
removeRoute: removeRoute,
}
}
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
ref := rm.refCountMap[prefix]
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
// Add route to the system, only if it's a new prefix
if ref.count == 0 {
log.Debugf("Adding route for prefix %s", prefix)
nexthop, intf, err := rm.addRoute(prefix)
if errors.Is(err, ErrRouteNotFound) {
return nil
}
if errors.Is(err, ErrRouteNotAllowed) {
log.Debugf("Adding route for prefix %s: %s", prefix, err)
}
if err != nil {
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
}
ref.nexthop = nexthop
ref.intf = intf
}
ref.count++
rm.refCountMap[prefix] = ref
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
return nil
}
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
prefixes, ok := rm.prefixMap[connID]
if !ok {
log.Debugf("No prefixes found for connection ID %s", connID)
return nil
}
var result *multierror.Error
for _, prefix := range prefixes {
ref := rm.refCountMap[prefix]
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
if ref.count == 1 {
log.Debugf("Removing route for prefix %s", prefix)
// TODO: don't fail if the route is not found
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
continue
}
delete(rm.refCountMap, prefix)
} else {
ref.count--
rm.refCountMap[prefix] = ref
}
}
delete(rm.prefixMap, connID)
return result.ErrorOrNil()
}
// Flush removes all references and routes from the system
func (rm *RouteManager) Flush() error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
var result *multierror.Error
for prefix := range rm.refCountMap {
log.Debugf("Removing route for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]ref{}
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
return result.ErrorOrNil()
}

View File

@@ -155,11 +155,13 @@ func (m *defaultServerRouter) cleanUp() {
log.Errorf("Failed to remove cleanup route: %v", err)
}
state := m.statusRecorder.GetLocalPeerState()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
}
state := m.statusRecorder.GetLocalPeerState()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
}
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
parsed, err := netip.ParsePrefix(source)
if err != nil {

View File

@@ -0,0 +1,428 @@
//go:build !android && !ios
package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRouteNotFound = errors.New("route not found")
var ErrRouteNotAllowed = errors.New("route not allowed")
// TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
defaultGateway, _, err := getNextHop(addr)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(defaultGateway) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
if defaultGateway.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
var exitIntf string
gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
if intf != nil {
exitIntf = intf.Name
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
}
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
r, err := netroute.New()
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Warnf("Failed to get route for %s: %v", ip, err)
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
}
return addr.Unmap(), intf, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
}
return addr, intf, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(
prefix netip.Prefix,
vpnIntf *iface.WGIface,
initialNextHop netip.Addr,
initialIntf *net.Interface,
) (netip.Addr, string, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, "", ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr)
if err != nil {
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
var exitIntf string
if intf != nil {
exitIntf = intf.Name
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
if initialIntf != nil {
exitIntf = initialIntf.Name
}
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, string, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}

View File

@@ -1,13 +1,33 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, string) error {
return nil
}
func removeVPNRoute(netip.Prefix, string) error {
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
)
@@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
continue
}
if len(m.Addrs) < 3 {
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
continue
}
addr, ok := toNetIPAddr(m.Addrs[0])
if !ok {
continue
}
mask, ok := toNetIPMASK(m.Addrs[2])
if !ok {
continue
cidr := 32
if mask := m.Addrs[2]; mask != nil {
cidr, ok = toCIDR(mask)
if !ok {
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
continue
}
}
cidr, _ := mask.Size()
routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() {
@@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
switch t := a.(type) {
case *route.Inet4Addr:
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
addr := netip.MustParseAddr(ip.String())
return addr, true
return netip.AddrFrom4(t.IP), true
default:
return netip.Addr{}, false
}
}
func toNetIPMASK(a route.Addr) (net.IPMask, bool) {
func toCIDR(a route.Addr) (int, bool) {
switch t := a.(type) {
case *route.Inet4Addr:
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
return mask, true
cidr, _ := mask.Size()
return cidr, true
default:
return nil, false
return 0, false
}
}

View File

@@ -1,13 +0,0 @@
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios
package routemanager
import "net/netip"
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
}

View File

@@ -0,0 +1,89 @@
//go:build darwin && !ios
package routemanager
import (
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return routeCmd("add", prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return routeCmd("delete", prefix, nexthop, intf)
}
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
if prefix.Addr().Is6() {
inet = "-inet6"
// Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 {
intf = "lo0"
}
}
args := []string{"-n", action, inet, network}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != "" {
args = append(args, "-interface", intf)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return fmt.Errorf("route cmd retry failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,138 @@
//go:build !ios
package routemanager
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var expectedVPNint = "utun100"
var expectedExternalInt = "lo0"
var expectedInternalInt = "lo0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
destination: "10.10.0.2:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := "lo0"
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return "lo0"
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

View File

@@ -1,13 +1,33 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, string) error {
return nil
}
func removeVPNRoute(netip.Prefix, string) error {
return nil
}

View File

@@ -4,17 +4,21 @@ package routemanager
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -33,6 +37,9 @@ const (
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
type ruleParams struct {
fwmark int
tableID int
@@ -64,7 +71,12 @@ func getSetupRules() []ruleParams {
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting() (err error) {
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
@@ -80,17 +92,26 @@ func setupRouting() (err error) {
rules := getSetupRules()
for _, rule := range rules {
if err := addRule(rule); err != nil {
return fmt.Errorf("%s: %w", rule.description, err)
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
isLegacy = true
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
}
}
return nil
return nil, nil, nil
}
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error {
if isLegacy {
return cleanupRoutingWithRouteManager(routeManager)
}
var result *multierror.Error
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
@@ -102,7 +123,7 @@ func cleanupRouting() error {
rules := getSetupRules()
for _, rule := range rules {
if err := removeAllRules(rule); err != nil {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
@@ -110,53 +131,104 @@ func cleanupRouting() error {
return result.ErrorOrNil()
}
func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error {
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 2
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func addVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
return genericAddVPNRoute(prefix, intf)
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err)
}
}
if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err)
}
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error {
func removeVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
return genericRemoveVPNRoute(prefix, intf)
}
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err)
}
}
if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4)
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err)
}
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
if err != nil {
return nil, fmt.Errorf("get v6 routes: %w", err)
}
return append(v4Routes, v6Routes...), nil
}
// getRoutes fetches routes from a specific routing table identified by tableID.
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
var prefixList []netip.Prefix
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
}
for _, route := range routes {
if route.Dst != nil {
addr, ok := netip.AddrFromSlice(route.Dst.IP)
if !ok {
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
}
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
}
}
}
return prefixList, nil
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: family,
Family: getAddressFamily(prefix),
}
if prefix != nil {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route.Dst = ipNet
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route.Dst = ipNet
if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
@@ -172,7 +244,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
// tableID specifies the routing table to which the unreachable route will be added.
func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -181,7 +253,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: ipFamily,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
@@ -192,7 +264,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
return nil
}
func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -201,7 +273,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: ipFamily,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
@@ -214,7 +286,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -223,7 +295,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int)
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: family,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
@@ -263,34 +335,6 @@ func flushRoutes(tableID, family int) error {
return result.ErrorOrNil()
}
// getRoutes fetches routes from a specific routing table identified by tableID.
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
var prefixList []netip.Prefix
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
}
for _, route := range routes {
if route.Dst != nil {
addr, ok := netip.AddrFromSlice(route.Dst.IP)
if !ok {
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
}
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
}
}
}
return prefixList, nil
}
func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
@@ -402,7 +446,7 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleDel(rule); err != nil {
return fmt.Errorf("remove routing rule: %w", err)
}
@@ -410,35 +454,58 @@ func removeRule(params ruleParams) error {
}
func removeAllRules(params ruleParams) error {
for {
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) {
break
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
for {
if ctx.Err() != nil {
done <- ctx.Err()
return
}
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
done <- nil
return
}
done <- err
return
}
return err
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
return nil
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr *string, intf *string, route *netlink.Route) error {
if addr != nil {
ip := net.ParseIP(*addr)
if ip == nil {
return fmt.Errorf("parsing address %s failed", *addr)
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
route.Gw = addr.AsSlice()
if intf == "" {
intf = addr.Zone()
}
route.Gw = ip
}
if intf != nil {
link, err := netlink.LinkByName(*intf)
if intf != "" {
link, err := netlink.LinkByName(intf)
if err != nil {
return fmt.Errorf("set interface %s: %w", *intf, err)
return fmt.Errorf("set interface %s: %w", intf, err)
}
route.LinkIndex = link.Attrs().Index
}
return nil
}
func getAddressFamily(prefix netip.Prefix) int {
if prefix.Addr().Is4() {
return netlink.FAMILY_V4
}
return netlink.FAMILY_V6
}

View File

@@ -6,34 +6,38 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"strings"
"syscall"
"testing"
"time"
"github.com/gopacket/gopacket"
"github.com/gopacket/gopacket/layers"
"github.com/gopacket/gopacket/pcap"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via physical interface",
destination: "10.10.0.2:53",
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.1:53",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...)
}
func TestEntryExists(t *testing.T) {
@@ -92,157 +96,7 @@ func TestEntryExists(t *testing.T) {
}
}
func TestRoutingWithTables(t *testing.T) {
testCases := []struct {
name string
destination string
captureInterface string
dialer *net.Dialer
packetExpectation PacketExpectation
}{
{
name: "To external host without fwmark via vpn",
destination: "192.0.2.1:53",
captureInterface: "wgtest0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with fwmark via physical interface",
destination: "192.0.2.1:53",
captureInterface: "dummyext0",
dialer: nbnet.NewDialer(),
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with fwmark via physical interface",
destination: "10.0.0.1:53",
captureInterface: "dummyint0",
dialer: nbnet.NewDialer(),
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
},
{
name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence
destination: "10.0.0.1:53",
captureInterface: "dummyint0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
},
{
name: "To unique vpn route with fwmark via physical interface",
destination: "172.16.0.1:53",
captureInterface: "dummyext0",
dialer: nbnet.NewDialer(),
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53),
},
{
name: "To unique vpn route without fwmark via vpn",
destination: "172.16.0.1:53",
captureInterface: "wgtest0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53),
},
{
name: "To more specific route without fwmark via vpn interface",
destination: "10.10.0.1:53",
captureInterface: "dummyint0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53),
},
{
name: "To more specific route (local) without fwmark via physical interface",
destination: "127.0.10.1:53",
captureInterface: "lo",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
wgIface, _, _ := setupTestEnv(t)
// default route exists in main table and vpn table
err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// 10.0.0.0/8 route exists in main table and vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// 10.10.0.0/24 more specific route exists in vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// 127.0.10.0/24 more specific route exists in vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// unique route in vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
filter := createBPFFilter(tc.destination)
handle := startPacketCapture(t, tc.captureInterface, filter)
sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer)
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
packet, err := packetSource.NextPacket()
require.NoError(t, err)
verifyPacket(t, packet, tc.packetExpectation)
})
}
}
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
t.Helper()
ipLayer := packet.Layer(layers.LayerTypeIPv4)
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
ip, ok := ipLayer.(*layers.IPv4)
require.True(t, ok, "Failed to cast to IPv4 layer")
// Convert both source and destination IP addresses to 16-byte representation
expectedSrcIP := exp.SrcIP.To16()
actualSrcIP := ip.SrcIP.To16()
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
expectedDstIP := exp.DstIP.To16()
actualDstIP := ip.DstIP.To16()
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
if exp.UDP {
udpLayer := packet.Layer(layers.LayerTypeUDP)
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
udp, ok := udpLayer.(*layers.UDP)
require.True(t, ok, "Failed to cast to UDP layer")
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
}
if exp.TCP {
tcpLayer := packet.Layer(layers.LayerTypeTCP)
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
tcp, ok := tcpLayer.(*layers.TCP)
require.True(t, ok, "Failed to cast to TCP layer")
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
}
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy {
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
@@ -264,35 +118,52 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
require.NoError(t, err)
}
return dummy
t.Cleanup(func() {
err := netlink.LinkDel(dummy)
assert.NoError(t, err)
})
return dummy.Name
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
t.Helper()
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
require.NoError(t, err)
// Handle existing routes with metric 0
var originalNexthop net.IP
var originalLinkIndex int
if dstIPNet.String() == "0.0.0.0/0" {
gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil {
var err error
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
t.Logf("Failed to fetch original gateway: %v", err)
}
// Handle existing routes with metric 0
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
if err == nil {
t.Cleanup(func() {
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0})
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
})
} else if !errors.Is(err, syscall.ESRCH) {
t.Logf("Failed to delete route: %v", err)
if originalNexthop != nil {
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
switch {
case err != nil && !errors.Is(err, syscall.ESRCH):
t.Logf("Failed to delete route: %v", err)
case err == nil:
t.Cleanup(func() {
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0})
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
})
default:
t.Logf("Failed to delete route: %v", err)
}
}
}
link, err := netlink.LinkByName(intf)
require.NoError(t, err)
linkIndex := link.Attrs().Index
route := &netlink.Route{
Dst: dstIPNet,
Gw: gw,
@@ -307,9 +178,9 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
require.NoError(t, err)
}
// fetchOriginalGateway returns the original gateway IP address and the interface index.
func fetchOriginalGateway(family int) (net.IP, int, error) {
routes, err := netlink.RouteList(nil, family)
if err != nil {
@@ -317,153 +188,20 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
}
for _, route := range routes {
if route.Dst == nil {
if route.Dst == nil && route.Priority == 0 {
return route.Gw, route.LinkIndex, nil
}
}
return nil, 0, fmt.Errorf("default route not found")
return nil, 0, ErrRouteNotFound
}
func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) {
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index)
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index)
t.Cleanup(func() {
err := netlink.LinkDel(defaultDummy)
assert.NoError(t, err)
err = netlink.LinkDel(otherDummy)
assert.NoError(t, err)
})
return defaultDummy.Name, otherDummy.Name
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet(nil)
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
require.NoError(t, err, "should create testing WireGuard interface")
t.Cleanup(func() {
wgInterface.Close()
})
return wgInterface
}
func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) {
t.Helper()
defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgIface.Close())
})
err := setupRouting()
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
return wgIface, defaultDummy, otherDummy
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()
inactive, err := pcap.NewInactiveHandle(intf)
require.NoError(t, err, "Failed to create inactive pcap handle")
defer inactive.CleanUp()
err = inactive.SetSnapLen(1600)
require.NoError(t, err, "Failed to set snap length on inactive handle")
err = inactive.SetTimeout(time.Second * 10)
require.NoError(t, err, "Failed to set timeout on inactive handle")
err = inactive.SetImmediateMode(true)
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
handle, err := inactive.Activate()
require.NoError(t, err, "Failed to activate pcap handle")
t.Cleanup(handle.Close)
err = handle.SetBPFFilter(filter)
require.NoError(t, err, "Failed to set BPF filter")
return handle
}
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) {
t.Helper()
if dialer == nil {
dialer = &net.Dialer{}
}
if sourcePort != 0 {
localUDPAddr := &net.UDPAddr{
IP: net.IPv4zero,
Port: sourcePort,
}
dialer.LocalAddr = localUDPAddr
}
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
conn, err := dialer.Dial("udp", destination)
require.NoError(t, err, "Failed to dial UDP")
defer conn.Close()
data, err := msg.Pack()
require.NoError(t, err, "Failed to pack DNS message")
_, err = conn.Write(data)
if err != nil {
if strings.Contains(err.Error(), "required key not available") {
t.Logf("Ignoring WireGuard key error: %v", err)
return
}
t.Fatalf("Failed to send DNS query: %v", err)
}
}
func createBPFFilter(destination string) string {
host, port, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
}
return "udp"
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

View File

@@ -1,148 +0,0 @@
//go:build !android
//nolint:unused
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"os/exec"
"runtime"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
)
var errRouteNotFound = fmt.Errorf("route not found")
func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
defaultGateway, err := getExistingRIBRouteGateway(defaultv4)
if err != nil && !errors.Is(err, errRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
addr := netip.MustParseAddr(defaultGateway.String())
if !prefix.Contains(addr) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(addr, 32)
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
if err != nil && !errors.Is(err, errRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "")
}
func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := genericAddRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return genericAddToRouteTable(prefix, addr, intf)
}
func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
return genericRemoveFromRouteTable(prefix, addr, intf)
}
func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error {
cmd := exec.Command("route", "add", prefix.String(), addr)
out, err := cmd.Output()
if err != nil {
return fmt.Errorf("add route: %w", err)
}
log.Debugf(string(out))
return nil
}
func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error {
args := []string{"delete", prefix.String()}
if runtime.GOOS == "darwin" {
args = append(args, addr)
}
cmd := exec.Command("route", args...)
out, err := cmd.Output()
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
log.Debugf(string(out))
return nil
}
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
r, err := netroute.New()
if err != nil {
return nil, fmt.Errorf("new netroute: %w", err)
}
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
if err != nil {
log.Errorf("Getting routes returned an error: %v", err)
return nil, errRouteNotFound
}
if gateway == nil {
return preferredSrc, nil
}
return gateway, nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}

View File

@@ -1,22 +1,23 @@
//go:build !linux || android
//go:build !linux && !ios
package routemanager
import (
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func setupRouting() error {
return nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf string) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@@ -1,80 +0,0 @@
//go:build !linux || android
package routemanager
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
require.NoError(t, setupRouting())
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is4() {
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}

View File

@@ -1,14 +1,14 @@
//go:build !android
//go:build !android && !ios
package routemanager
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"runtime"
"strings"
"testing"
@@ -22,47 +22,9 @@ import (
"github.com/netbirdio/netbird/iface"
)
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" {
outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String())
require.NoError(t, err, "getOutgoingInterfaceLinux should not return error")
if invert {
require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface")
} else {
require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface")
}
return
}
prefixGateway, err := getExistingRIBRouteGateway(prefix)
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
}
}
func getOutgoingInterfaceLinux(destination string) (string, error) {
cmd := exec.Command("ip", "route", "get", destination)
output, err := cmd.Output()
if err != nil {
return "", fmt.Errorf("executing ip route get: %w", err)
}
return parseOutgoingInterface(string(output)), nil
}
func parseOutgoingInterface(routeGetOutput string) string {
fields := strings.Fields(routeGetOutput)
for i, field := range fields {
if field == "dev" && i+1 < len(fields) {
return fields[i+1]
}
}
return ""
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddRemoveRoutes(t *testing.T) {
@@ -99,14 +61,14 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
require.NoError(t, setupRouting())
_, _, err = setupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
@@ -116,13 +78,13 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
require.NoError(t, err, "getNextHop should not return err")
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
@@ -135,12 +97,12 @@ func TestAddRemoveRoutes(t *testing.T) {
}
}
func TestGetExistingRIBRouteGateway(t *testing.T) {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
func TestGetNextHop(t *testing.T) {
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if gateway == nil {
if !gateway.IsValid() {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
@@ -162,11 +124,11 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
}
}
localIP, err := getExistingRIBRouteGateway(testingPrefix)
localIP, _, err := getNextHop(testingPrefix.Addr())
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if localIP == nil {
if !localIP.IsValid() {
t.Fatal("should return a gateway for local network")
}
if localIP.String() == gateway.String() {
@@ -177,8 +139,8 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
}
}
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultGateway: ", defaultGateway)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
@@ -238,21 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
require.NoError(t, setupRouting())
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
MockAddr := wgInterface.Address().IP.String()
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name())
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name())
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@@ -262,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name())
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "should not return err")
}
@@ -272,11 +227,176 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
// Linux uses a separate routing table, so the route can exist in both tables.
// The main routing table takes precedence over the wireguard routing table.
if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" {
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is6() {
continue
}
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
continue
}
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
continue
}
addressPrefixes = append(addressPrefixes, p.Masked())
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet()
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
require.NoError(t, err, "should create testing WireGuard interface")
t.Cleanup(func() {
wgInterface.Close()
})
return wgInterface
}
func setupTestEnv(t *testing.T) {
t.Helper()
setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgIface.Close())
})
_, _, err := setupRouting(nil, wgIface)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixGateway, _, err := getNextHop(prefix.Addr())
require.NoError(t, err, "getNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
}
}

View File

@@ -0,0 +1,234 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package routemanager
import (
"fmt"
"net"
"strings"
"testing"
"time"
"github.com/gopacket/gopacket"
"github.com/gopacket/gopacket/layers"
"github.com/gopacket/gopacket/pcap"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
type testCase struct {
name string
destination string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
destination: "192.0.2.1:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
destination: "172.16.0.2:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
func TestRouting(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
filter := createBPFFilter(tc.destination)
handle := startPacketCapture(t, tc.expectedInterface, filter)
sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer)
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
packet, err := packetSource.NextPacket()
require.NoError(t, err)
verifyPacket(t, packet, tc.expectedPacket)
})
}
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()
inactive, err := pcap.NewInactiveHandle(intf)
require.NoError(t, err, "Failed to create inactive pcap handle")
defer inactive.CleanUp()
err = inactive.SetSnapLen(1600)
require.NoError(t, err, "Failed to set snap length on inactive handle")
err = inactive.SetTimeout(time.Second * 10)
require.NoError(t, err, "Failed to set timeout on inactive handle")
err = inactive.SetImmediateMode(true)
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
handle, err := inactive.Activate()
require.NoError(t, err, "Failed to activate pcap handle")
t.Cleanup(handle.Close)
err = handle.SetBPFFilter(filter)
require.NoError(t, err, "Failed to set BPF filter")
return handle
}
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) {
t.Helper()
if dialer == nil {
dialer = &net.Dialer{}
}
if sourcePort != 0 {
localUDPAddr := &net.UDPAddr{
IP: net.IPv4zero,
Port: sourcePort,
}
switch dialer := dialer.(type) {
case *nbnet.Dialer:
dialer.LocalAddr = localUDPAddr
case *net.Dialer:
dialer.LocalAddr = localUDPAddr
default:
t.Fatal("Unsupported dialer type")
}
}
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
conn, err := dialer.Dial("udp", destination)
require.NoError(t, err, "Failed to dial UDP")
defer conn.Close()
data, err := msg.Pack()
require.NoError(t, err, "Failed to pack DNS message")
_, err = conn.Write(data)
if err != nil {
if strings.Contains(err.Error(), "required key not available") {
t.Logf("Ignoring WireGuard key error: %v", err)
return
}
t.Fatalf("Failed to send DNS query: %v", err)
}
}
func createBPFFilter(destination string) string {
host, port, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
}
return "udp"
}
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
t.Helper()
ipLayer := packet.Layer(layers.LayerTypeIPv4)
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
ip, ok := ipLayer.(*layers.IPv4)
require.True(t, ok, "Failed to cast to IPv4 layer")
// Convert both source and destination IP addresses to 16-byte representation
expectedSrcIP := exp.SrcIP.To16()
actualSrcIP := ip.SrcIP.To16()
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
expectedDstIP := exp.DstIP.To16()
actualDstIP := ip.DstIP.To16()
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
if exp.UDP {
udpLayer := packet.Layer(layers.LayerTypeUDP)
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
udp, ok := udpLayer.(*layers.UDP)
require.True(t, ok, "Failed to cast to UDP layer")
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
}
if exp.TCP {
tcpLayer := packet.Layer(layers.LayerTypeTCP)
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
tcp, ok := tcpLayer.(*layers.TCP)
require.True(t, ok, "Failed to cast to TCP layer")
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
}
}

View File

@@ -6,9 +6,14 @@ import (
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
type Win32_IP4RouteTable struct {
@@ -16,6 +21,16 @@ type Win32_IP4RouteTable struct {
Mask string
}
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func getRoutesFromTable() ([]netip.Prefix, error) {
var routes []Win32_IP4RouteTable
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
@@ -48,10 +63,85 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
return prefixList, nil
}
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
destinationPrefix := prefix.String()
psCmd := "New-NetRoute"
addressFamily := "IPv4"
if prefix.Addr().Is6() {
addressFamily = "IPv6"
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`,
psCmd, addressFamily, destinationPrefix,
)
if intfIdx != "" {
script = fmt.Sprintf(
`%s -InterfaceIndex %s`, script, intfIdx,
)
} else {
script = fmt.Sprintf(
`%s -InterfaceAlias "%s"`, script, intf,
)
}
if nexthop.IsValid() {
script = fmt.Sprintf(
`%s -NextHop "%s"`, script, nexthop,
)
}
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
log.Tracef("PowerShell %s: %s", script, string(out))
if err != nil {
return fmt.Errorf("PowerShell add route: %w", err)
}
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
}
return nil
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
var intfIdx string
if nexthop.Zone() != "" {
intfIdx = nexthop.Zone()
nexthop.WithZone("")
}
// Powershell doesn't support adding routes without an interface but allows to add interface by name
if intf != "" || intfIdx != "" {
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String())
}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}

View File

@@ -0,0 +1,289 @@
package routemanager
import (
"context"
"encoding/json"
"fmt"
"net"
"os/exec"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
var expectedExtInt = "Ethernet1"
type RouteInfo struct {
NextHop string `json:"nexthop"`
InterfaceAlias string `json:"interfacealias"`
RouteMetric int `json:"routemetric"`
}
type FindNetRouteOutput struct {
IPAddress string `json:"IPAddress"`
InterfaceIndex int `json:"InterfaceIndex"`
InterfaceAlias string `json:"InterfaceAlias"`
AddressFamily int `json:"AddressFamily"`
NextHop string `json:"NextHop"`
DestinationPrefix string `json:"DestinationPrefix"`
}
type testCase struct {
name string
destination string
expectedSourceIP string
expectedDestPrefix string
expectedNextHop string
expectedInterface string
dialer dialer
}
var expectedVPNint = "wgtest0"
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
destination: "192.0.2.1:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "128.0.0.0/1",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedDestPrefix: "192.0.2.1/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedDestPrefix: "10.0.0.2/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedDestPrefix: "172.16.0.2/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To unique vpn route without custom dialer via vpn",
destination: "172.16.0.2:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "172.16.0.0/12",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route without custom dialer via vpn interface",
destination: "10.10.0.2:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "10.10.0.0/24",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
}
func TestRouting(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
route, err := fetchOriginalGateway()
require.NoError(t, err, "Failed to fetch original gateway")
ip, err := fetchInterfaceIP(route.InterfaceAlias)
require.NoError(t, err, "Failed to fetch interface IP")
output := testRoute(t, tc.destination, tc.dialer)
if tc.expectedInterface == expectedExtInt {
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
} else {
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
}
})
}
}
// fetchInterfaceIP fetches the IPv4 address of the specified interface.
func fetchInterfaceIP(interfaceAlias string) (string, error) {
script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
out, err := exec.Command("powershell", "-Command", script).Output()
if err != nil {
return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
}
ip := strings.TrimSpace(string(out))
return ip, nil
}
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
conn, err := dialer.DialContext(ctx, "udp", destination)
require.NoError(t, err, "Failed to dial destination")
defer func() {
err := conn.Close()
assert.NoError(t, err, "Failed to close connection")
}()
host, _, err := net.SplitHostPort(destination)
require.NoError(t, err)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
out, err := exec.Command("powershell", "-Command", script).Output()
require.NoError(t, err, "Failed to execute Find-NetRoute")
var outputs []FindNetRouteOutput
err = json.Unmarshal(out, &outputs)
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
require.Greater(t, len(outputs), 0, "No route found for destination")
combinedOutput := combineOutputs(outputs)
return combinedOutput
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
}
var routeInfo RouteInfo
err = json.Unmarshal(output, &routeInfo)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
}
return &routeInfo, nil
}
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
t.Helper()
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
}
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput
for _, output := range outputs {
if output.IPAddress != "" {
combined.IPAddress = output.IPAddress
}
if output.InterfaceIndex != 0 {
combined.InterfaceIndex = output.InterfaceIndex
}
if output.InterfaceAlias != "" {
combined.InterfaceAlias = output.InterfaceAlias
}
if output.AddressFamily != 0 {
combined.AddressFamily = output.AddressFamily
}
if output.NextHop != "" {
combined.NextHop = output.NextHop
}
if output.DestinationPrefix != "" {
combined.DestinationPrefix = output.DestinationPrefix
}
}
return &combined
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
}

View File

@@ -1,10 +1,8 @@
package wgproxy
import (
"context"
"fmt"
nbnet "github.com/netbirdio/netbird/util/net"
"net"
)
const (
@@ -25,7 +23,7 @@ func (pl portLookup) searchFreePort() (int, error) {
}
func (pl portLookup) tryToBind(port int) error {
l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port))
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
if err != nil {
return err
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
@@ -29,7 +30,7 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex
rawConn net.PacketConn
conn *net.UDPConn
conn transport.UDPConn
}
// NewWGEBPFProxy create new WGEBPFProxy instance
@@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error {
IP: net.ParseIP("127.0.0.1"),
}
p.conn, err = nbnet.ListenUDP("udp", &addr)
conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
if cErr != nil {
@@ -75,6 +76,7 @@ func (p *WGEBPFProxy) Listen() error {
}
return err
}
p.conn = conn
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", wgPorxyPort)

View File

@@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.Routes),
Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)

View File

@@ -25,8 +25,6 @@ func Detect(ctx context.Context) string {
detectDigitalOcean,
detectGCP,
detectOracle,
detectIBMCloud,
detectSoftlayer,
detectVultr,
}

View File

@@ -6,7 +6,7 @@ import (
)
func detectGCP(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil)
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil)
if err != nil {
return ""
}

View File

@@ -1,54 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectIBMCloud(ctx context.Context) string {
v1ResultChan := make(chan bool, 1)
v2ResultChan := make(chan bool, 1)
go func() {
v1ResultChan <- detectIBMSecure(ctx)
}()
go func() {
v2ResultChan <- detectIBM(ctx)
}()
v1Result, v2Result := <-v1ResultChan, <-v2ResultChan
if v1Result || v2Result {
return "IBM Cloud"
}
return ""
}
func detectIBMSecure(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
if err != nil {
return false
}
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
func detectIBM(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
if err != nil {
return false
}
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}

View File

@@ -1,25 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectSoftlayer(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil)
if err != nil {
return ""
}
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
// Since SoftLayer was acquired by IBM, we should return "IBM Cloud"
return "IBM Cloud"
}
return ""
}

4
go.mod
View File

@@ -53,7 +53,7 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
github.com/libp2p/go-netroute v0.2.0
github.com/libp2p/go-netroute v0.2.1
github.com/magiconair/properties v1.8.5
github.com/mattn/go-sqlite3 v1.14.19
github.com/mdlayher/socket v0.4.1
@@ -172,7 +172,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2023
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

10
go.sum
View File

@@ -345,8 +345,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE=
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU=
github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ=
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
@@ -389,8 +389,8 @@ github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5 h1:m48qfB2ILlFx3oZlw7aEeD+V6vXnMb0hNwmDCtdcgv0=
github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -659,7 +659,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
@@ -746,7 +745,6 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -13,14 +13,6 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
@@ -28,6 +20,8 @@ type ICEBind struct {
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
receiverCreator *receiverCreator
}
func NewICEBind(transportNet transport.Net) *ICEBind {
@@ -35,9 +29,9 @@ func NewICEBind(transportNet transport.Net) *ICEBind {
transportNet: transportNet,
}
rc := receiverCreator{
ib,
}
rc := newReceiverCreator(ib)
ib.receiverCreator = rc
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
@@ -53,16 +47,22 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
func (s *ICEBind) SetTurnConn(conn interface{}) {
s.receiverCreator.setTurnConn(conn)
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn, netConn net.PacketConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
},
)
if conn != nil {
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
},
)
}
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
@@ -71,17 +71,40 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
if netConn != nil {
log.Debugf("----read from turn conn...")
msg := &(*msgs)[0]
msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0])
if err != nil {
return 0, err
}
log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N)
numMsgs = 1
} else {
log.Debugf("----read from pc...")
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
if netConn != nil {
log.Debugf("----read from turn conn...")
msg := &(*msgs)[0]
msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0])
if err != nil {
return 0, err
}
log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N)
numMsgs = 1
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
@@ -95,7 +118,10 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
ep := &wgConn.StdNetEndpoint{
AddrPort: addrPort,
Conn: netConn,
}
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}

View File

@@ -0,0 +1,38 @@
package bind
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
relayConn net.PacketConn
}
func newReceiverCreator(iceBind *ICEBind) *receiverCreator {
return &receiverCreator{
iceBind: iceBind,
}
}
func (rc *receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn, nil)
}
func (rc *receiverCreator) CreateRelayReceiverFn(msgPool *sync.Pool) wgConn.ReceiveFunc {
if rc.relayConn == nil {
log.Debugf("-------rc.conn is nil")
return nil
}
return rc.iceBind.createIPv4ReceiverFn(msgPool, nil, nil, rc.relayConn)
}
func (rc *receiverCreator) setTurnConn(relayConn interface{}) {
log.Debug("------ SET TURN CONN")
rc.relayConn = relayConn.(net.PacketConn)
}

View File

@@ -150,3 +150,10 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey)
}
func (w *WGIface) SetRelayConn(conn interface{}) {
w.mu.Lock()
defer w.mu.Unlock()
w.tun.SetTurnConn(conn)
}

View File

@@ -85,23 +85,27 @@ func tunModuleIsLoaded() bool {
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
/*
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
return true
}
if canCreateFakeWireGuardInterface() {
return true
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
return loaded
return loaded
*/
}
func canCreateFakeWireGuardInterface() bool {

View File

@@ -15,4 +15,5 @@ type wgTunDevice interface {
DeviceName() string
Close() error
Wrapper() *DeviceWrapper // todo eliminate this function
SetTurnConn(conn interface{})
}

View File

@@ -28,6 +28,14 @@ type tunDevice struct {
configurer wgConfigurer
}
func (t *tunDevice) SetTurnConn(conn interface{}) {
t.iceBind.SetTurnConn(conn)
err := t.device.BindUpdate()
if err != nil {
log.Errorf("failed to update bind: %v", err)
}
}
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
return &tunDevice{
name: name,

View File

@@ -31,6 +31,11 @@ type tunKernelDevice struct {
udpMux *bind.UniversalUDPMuxDefault
}
func (t *tunKernelDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
ctx, cancel := context.WithCancel(context.Background())
return &tunKernelDevice{

View File

@@ -30,6 +30,11 @@ type tunNetstackDevice struct {
configurer wgConfigurer
}
func (t *tunNetstackDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice {
return &tunNetstackDevice{
name: name,

View File

@@ -54,7 +54,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(device.LogLevelError, "[netbird] "),
)
err = t.assignAddr()
@@ -70,6 +70,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.configurer.close()
return nil, err
}
log.Debugf("configuration done")
return t.configurer, nil
}
@@ -125,6 +126,14 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
return t.wrapper
}
func (t *tunUSPDevice) SetTurnConn(conn interface{}) {
t.iceBind.SetTurnConn(conn)
err := t.device.BindUpdate()
if err != nil {
log.Errorf("failed to update bind: %v", err)
}
}
// assignAddr Adds IP address to the tunnel interface
func (t *tunUSPDevice) assignAddr() error {
link := newWGLink(t.name)

View File

@@ -58,6 +58,7 @@ services:
command: [
"--port", "443",
"--log-file", "console",
"--log-level", "info",
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"

View File

@@ -2,7 +2,7 @@ version: "3"
services:
#UI dashboard
dashboard:
image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
restart: unless-stopped
#ports:
# - 80:80

View File

@@ -242,19 +242,19 @@ type UserPermissions struct {
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"`
Permissions UserPermissions `json:"permissions"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -278,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
return routes
}
// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
}
if user.Id != account.CreatedBy {
if user.Role != UserRoleOwner {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
for _, otherUser := range account.Users {

View File

@@ -115,7 +115,15 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret)
data.Set("grant_type", ac.clientConfig.GrantType)
data.Set("scope", "https://graph.microsoft.com/.default")
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
if err != nil {
return nil, err
}
// get base url and add "/.default" as scope
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
scopeURL := baseURL + "/.default"
data.Set("scope", scopeURL)
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)

View File

@@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error {
return nil
}
// parseOktaUserToUserData parse okta user to UserData.
// parseOktaUser parse okta user to UserData.
func parseOktaUser(user *okta.User) (*UserData, error) {
var oktaUser struct {
Email string `json:"email"`

View File

@@ -706,7 +706,7 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
return nil
}
// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
if am.UpdateIntegratedValidatorGroupsFunc != nil {
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)

View File

@@ -551,8 +551,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if requiresApproval {
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
@@ -563,11 +563,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
am.updateAccountPeers(account)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
}
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
}
// LoginPeer logs in or registers a peer.

View File

@@ -56,7 +56,7 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
// MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type,
rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) {
rosenpassPubKey []byte, rosenpassAddr, relayedAddress, serverRefIP string) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
@@ -69,6 +69,10 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre
RosenpassPubKey: rosenpassPubKey,
RosenpassServerAddr: rosenpassAddr,
},
Relay: &proto.Relay{
RelayedAddress: relayedAddress,
SrvRefAddress: serverRefIP,
},
},
}, nil
}

View File

@@ -215,16 +215,21 @@ type Body struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
// wgListenPort is an actual WireGuard listen port
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
Mode *Mode `protobuf:"bytes,5,opt,name=mode,proto3" json:"mode,omitempty"`
// featuresSupported list of supported features by the client of this protocol
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
// is this optional or mandatory?
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
Relay *Relay `protobuf:"bytes,8,opt,name=relay,proto3" json:"relay,omitempty"`
}
func (x *Body) Reset() {
@@ -308,13 +313,18 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
return nil
}
func (x *Body) GetRelay() *Relay {
if x != nil {
return x.Relay
}
return nil
}
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Direct *bool `protobuf:"varint,1,opt,name=direct,proto3,oneof" json:"direct,omitempty"`
}
func (x *Mode) Reset() {
@@ -349,11 +359,59 @@ func (*Mode) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{3}
}
func (x *Mode) GetDirect() bool {
if x != nil && x.Direct != nil {
return *x.Direct
type Relay struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RelayedAddress string `protobuf:"bytes,1,opt,name=relayedAddress,proto3" json:"relayedAddress,omitempty"`
SrvRefAddress string `protobuf:"bytes,2,opt,name=srvRefAddress,proto3" json:"srvRefAddress,omitempty"`
}
func (x *Relay) Reset() {
*x = Relay{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
return false
}
func (x *Relay) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Relay) ProtoMessage() {}
func (x *Relay) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Relay.ProtoReflect.Descriptor instead.
func (*Relay) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{4}
}
func (x *Relay) GetRelayedAddress() string {
if x != nil {
return x.RelayedAddress
}
return ""
}
func (x *Relay) GetSrvRefAddress() string {
if x != nil {
return x.SrvRefAddress
}
return ""
}
type RosenpassConfig struct {
@@ -369,7 +427,7 @@ type RosenpassConfig struct {
func (x *RosenpassConfig) Reset() {
*x = RosenpassConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[4]
mi := &file_signalexchange_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -382,7 +440,7 @@ func (x *RosenpassConfig) String() string {
func (*RosenpassConfig) ProtoMessage() {}
func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[4]
mi := &file_signalexchange_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -395,7 +453,7 @@ func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use RosenpassConfig.ProtoReflect.Descriptor instead.
func (*RosenpassConfig) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{4}
return file_signalexchange_proto_rawDescGZIP(), []int{5}
}
func (x *RosenpassConfig) GetRosenpassPubKey() []byte {
@@ -431,7 +489,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -451,33 +509,39 @@ var file_signalexchange_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,
0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d,
0x0a, 0x0f, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69,
0x67, 0x12, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75,
0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01,
0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67,
0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72,
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59,
0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12,
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2b, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18,
0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x52, 0x05, 0x72, 0x65,
0x6c, 0x61, 0x79, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f,
0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52,
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10,
0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x06, 0x0a, 0x04, 0x4d,
0x6f, 0x64, 0x65, 0x22, 0x55, 0x0a, 0x05, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x26, 0x0a, 0x0e,
0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64,
0x72, 0x65, 0x73, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x72, 0x76, 0x52, 0x65, 0x66, 0x41, 0x64,
0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x73, 0x72, 0x76,
0x52, 0x65, 0x66, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x6d, 0x0a, 0x0f, 0x52, 0x6f,
0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a,
0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e,
0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04,
0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f,
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63,
0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e,
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -493,29 +557,31 @@ func file_signalexchange_proto_rawDescGZIP() []byte {
}
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_signalexchange_proto_goTypes = []interface{}{
(Body_Type)(0), // 0: signalexchange.Body.Type
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
(*Message)(nil), // 2: signalexchange.Message
(*Body)(nil), // 3: signalexchange.Body
(*Mode)(nil), // 4: signalexchange.Mode
(*RosenpassConfig)(nil), // 5: signalexchange.RosenpassConfig
(*Relay)(nil), // 5: signalexchange.Relay
(*RosenpassConfig)(nil), // 6: signalexchange.RosenpassConfig
}
var file_signalexchange_proto_depIdxs = []int32{
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
0, // 1: signalexchange.Body.type:type_name -> signalexchange.Body.Type
4, // 2: signalexchange.Body.mode:type_name -> signalexchange.Mode
5, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
1, // 4: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 5: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
6, // [6:8] is the sub-list for method output_type
4, // [4:6] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
6, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
5, // 4: signalexchange.Body.relay:type_name -> signalexchange.Relay
1, // 5: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 8: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
7, // [7:9] is the sub-list for method output_type
5, // [5:7] is the sub-list for method input_type
5, // [5:5] is the sub-list for extension type_name
5, // [5:5] is the sub-list for extension extendee
0, // [0:5] is the sub-list for field type_name
}
func init() { file_signalexchange_proto_init() }
@@ -573,6 +639,18 @@ func file_signalexchange_proto_init() {
}
}
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Relay); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_signalexchange_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RosenpassConfig); i {
case 0:
return &v.state
@@ -585,14 +663,13 @@ func file_signalexchange_proto_init() {
}
}
}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_signalexchange_proto_rawDesc,
NumEnums: 1,
NumMessages: 5,
NumMessages: 6,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -49,22 +49,33 @@ message Body {
MODE = 4;
}
Type type = 1;
// these will be set in OFFER, ANSWER, CANDIDATE only
string payload = 2;
// wgListenPort is an actual WireGuard listen port
// these will be set in OFFER, ANSWER, CANDIDATE only
uint32 wgListenPort = 3;
// these will be set in OFFER, ANSWER, CANDIDATE only
string netBirdVersion = 4;
Mode mode = 5;
// featuresSupported list of supported features by the client of this protocol
repeated uint32 featuresSupported = 6;
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
// is this optional or mandatory?
RosenpassConfig rosenpassConfig = 7;
Relay relay = 8;
}
// Mode indicates a connection mode
message Mode {
optional bool direct = 1;
}
message Relay {
string relayedAddress = 1;
string srvRefAddress = 2;
}
message RosenpassConfig {

38
util/grpc/dialer.go Normal file
View File

@@ -0,0 +1,38 @@
package grpc
import (
"context"
"net"
"os/user"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
log.Fatalf("failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, err
}
return conn, nil
})
}

View File

@@ -1,9 +0,0 @@
//go:build !linux || android
package grpc
import "google.golang.org/grpc"
func WithCustomDialer() grpc.DialOption {
return grpc.EmptyDialOption{}
}

View File

@@ -1,18 +0,0 @@
//go:build !android
package grpc
import (
"context"
"net"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return nbnet.NewDialer().DialContext(ctx, "tcp", addr)
})
}

21
util/net/dialer.go Normal file
View File

@@ -0,0 +1,21 @@
package net
import (
"net"
)
// Dialer extends the standard net.Dialer with the ability to execute hooks before
// and after connections. This can be used to bypass the VPN for connections using this dialer.
type Dialer struct {
*net.Dialer
}
// NewDialer returns a customized net.Dialer with overridden Control method
func NewDialer() *Dialer {
dialer := &Dialer{
Dialer: &net.Dialer{},
}
dialer.init()
return dialer
}

View File

@@ -1,19 +1,163 @@
//go:build !linux || android
//go:build !android && !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
)
func NewDialer() *net.Dialer {
return &net.Dialer{}
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
var (
dialerDialHooksMutex sync.RWMutex
dialerDialHooks []DialerDialHookFunc
dialerCloseHooksMutex sync.RWMutex
dialerCloseHooks []DialerCloseHookFunc
)
// AddDialerHook allows adding a new hook to be executed before dialing.
func AddDialerHook(hook DialerDialHookFunc) {
dialerDialHooksMutex.Lock()
defer dialerDialHooksMutex.Unlock()
dialerDialHooks = append(dialerDialHooks, hook)
}
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
func AddDialerCloseHook(hook DialerCloseHookFunc) {
dialerCloseHooksMutex.Lock()
defer dialerCloseHooksMutex.Unlock()
dialerCloseHooks = append(dialerCloseHooks, hook)
}
// RemoveDialerHook removes all dialer hooks.
func RemoveDialerHooks() {
dialerDialHooksMutex.Lock()
defer dialerDialHooksMutex.Unlock()
dialerDialHooks = nil
dialerCloseHooksMutex.Lock()
defer dialerCloseHooksMutex.Unlock()
dialerCloseHooks = nil
}
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
}
connID := GenerateConnID()
if dialerDialHooks != nil {
if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
log.Errorf("Failed to call dialer hooks: %v", err)
}
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("dial: %w", err)
}
// Wrap the connection in Conn to handle Close with hooks
return &Conn{Conn: conn, ID: connID}, nil
}
// Dial wraps the net.Dialer's Dial method to use the custom connection
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
func (c *Conn) Close() error {
err := c.Conn.Close()
dialerCloseHooksMutex.RLock()
defer dialerCloseHooksMutex.RUnlock()
for _, hook := range dialerCloseHooks {
if err := hook(c.ID, &c.Conn); err != nil {
log.Errorf("Error executing dialer close hook: %v", err)
}
}
return err
}
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("split host and port: %w", err)
}
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
var result *multierror.Error
dialerDialHooksMutex.RLock()
defer dialerDialHooksMutex.RUnlock()
for _, hook := range dialerDialHooks {
if err := hook(ctx, connID, ips); err != nil {
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
}
}
return result.ErrorOrNil()
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

View File

@@ -2,59 +2,11 @@
package net
import (
"context"
"fmt"
"net"
"syscall"
import "syscall"
log "github.com/sirupsen/logrus"
)
func NewDialer() *net.Dialer {
return &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return SetRawSocketMark(c)
},
// init configures the net.Dialer Control function to set the fwmark on the socket
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c)
}
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type")
}
return tcpConn, nil
}

15
util/net/dialer_mobile.go Normal file
View File

@@ -0,0 +1,15 @@
//go:build android || ios
package net
import (
"net"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
}

View File

@@ -0,0 +1,6 @@
//go:build !linux || android
package net
func (d *Dialer) init() {
}

21
util/net/listener.go Normal file
View File

@@ -0,0 +1,21 @@
package net
import (
"net"
)
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
type ListenerConfig struct {
*net.ListenConfig
}
// NewListener creates a new ListenerConfig instance.
func NewListener() *ListenerConfig {
listener := &ListenerConfig{
ListenConfig: &net.ListenConfig{},
}
listener.init()
return listener
}

View File

@@ -1,13 +1,163 @@
//go:build !linux || android
//go:build !android && !ios
package net
import "net"
import (
"context"
"fmt"
"net"
"sync"
func NewListener() *net.ListenConfig {
return &net.ListenConfig{}
log "github.com/sirupsen/logrus"
)
// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
var (
listenerWriteHooksMutex sync.RWMutex
listenerWriteHooks []ListenerWriteHookFunc
listenerCloseHooksMutex sync.RWMutex
listenerCloseHooks []ListenerCloseHookFunc
)
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
func AddListenerWriteHook(hook ListenerWriteHookFunc) {
listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock()
listenerWriteHooks = append(listenerWriteHooks, hook)
}
func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, locAddr)
// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
func AddListenerCloseHook(hook ListenerCloseHookFunc) {
listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = append(listenerCloseHooks, hook)
}
// RemoveListenerHooks removes all dialer hooks.
func RemoveListenerHooks() {
listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock()
listenerWriteHooks = nil
listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = nil
}
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
}
connID := GenerateConnID()
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
}
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
type PacketConn struct {
net.PacketConn
ID ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
callWriteHooks(c.ID, c.seenAddrs, b, addr)
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
func (c *PacketConn) Close() error {
c.seenAddrs = &sync.Map{}
return closeConn(c.ID, c.PacketConn)
}
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
type UDPConn struct {
*net.UDPConn
ID ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
callWriteHooks(c.ID, c.seenAddrs, b, addr)
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
func (c *UDPConn) Close() error {
c.seenAddrs = &sync.Map{}
return closeConn(c.ID, c.UDPConn)
}
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
ipStr, _, splitErr := net.SplitHostPort(addr.String())
if splitErr != nil {
log.Errorf("Error splitting IP address and port: %v", splitErr)
return
}
ip, err := net.ResolveIPAddr("ip", ipStr)
if err != nil {
log.Errorf("Error resolving IP address: %v", err)
return
}
log.Debugf("Listener resolved IP for %s: %s", addr, ip)
func() {
listenerWriteHooksMutex.RLock()
defer listenerWriteHooksMutex.RUnlock()
for _, hook := range listenerWriteHooks {
if err := hook(id, ip, b); err != nil {
log.Errorf("Error executing listener write hook: %v", err)
}
}
}()
}
}
func closeConn(id ConnectionID, conn net.PacketConn) error {
err := conn.Close()
listenerCloseHooksMutex.RLock()
defer listenerCloseHooksMutex.RUnlock()
for _, hook := range listenerCloseHooks {
if err := hook(id, conn); err != nil {
log.Errorf("Error executing listener close hook: %v", err)
}
}
return err
}
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
packetConn := conn.(*PacketConn)
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
if !ok {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
}
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
}

View File

@@ -3,28 +3,12 @@
package net
import (
"context"
"fmt"
"net"
"syscall"
)
func NewListener() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return SetRawSocketMark(c)
},
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c)
}
}
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err)
}
udpConn, ok := pc.(*net.UDPConn)
if !ok {
return nil, fmt.Errorf("packetConn is not a *net.UDPConn")
}
return udpConn, nil
}

View File

@@ -0,0 +1,11 @@
//go:build android || ios
package net
import (
"net"
)
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, laddr)
}

View File

@@ -0,0 +1,6 @@
//go:build !linux || android
package net
func (l *ListenerConfig) init() {
}

View File

@@ -1,6 +1,17 @@
package net
import "github.com/google/uuid"
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
)
// ConnectionID provides a globally unique identifier for network connections.
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
type ConnectionID string
// GenerateConnID generates a unique identifier for each connection.
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}