Compare commits

..

1 Commits

Author SHA1 Message Date
Givi Khojanashvili
2ecba370a3 Rollback simple ACL rules. 2023-04-11 21:32:06 +04:00
83 changed files with 775 additions and 3267 deletions

View File

@@ -6,10 +6,6 @@ on:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: macos-latest

View File

@@ -6,10 +6,6 @@ on:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
strategy:
@@ -70,7 +66,7 @@ jobs:
run: go mod tidy
- name: Generate Iface Test bin
run: go test -c -o iface-testing.bin ./iface/
run: go test -c -o iface-testing.bin ./iface/...
- name: Generate RouteManager Test bin
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...

View File

@@ -6,45 +6,47 @@ on:
- main
pull_request:
env:
downloadPath: '${{ github.workspace }}\temp'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
pre:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- run: bash -x wireguard_nt.sh
working-directory: client
- uses: actions/upload-artifact@v2
with:
name: syso
path: client/*.syso
retention-days: 1
test:
needs: pre
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: Install Go
uses: actions/setup-go@v4
id: go
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
- uses: actions/cache@v2
with:
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
path: |
%LocalAppData%\go-build
~\go\pkg\mod
~\AppData\Local\go-build
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Decompressing wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- uses: actions/download-artifact@v2
with:
name: syso
path: iface\
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
- run: choco install -y sysinternals
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt
- name: Test
run: go test -tags=load_wgnt_from_rsrc -timeout 5m -p 1 ./...

View File

@@ -1,8 +1,5 @@
name: golangci-lint
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
golangci:
name: lint

View File

@@ -7,9 +7,7 @@ on:
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: macos-latest

View File

@@ -7,9 +7,7 @@ on:
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: ubuntu-latest

View File

@@ -9,13 +9,9 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.6"
SIGN_PIPE_VER: "v0.0.5"
GORELEASER_VER: "v1.14.1"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
release:
runs-on: ubuntu-latest
@@ -25,6 +21,10 @@ jobs:
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Generate syso with DLL
run: bash -x wireguard_nt.sh
working-directory: client
-
name: Set up Go
uses: actions/setup-go@v2
@@ -59,17 +59,6 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2

View File

@@ -6,10 +6,6 @@ on:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-latest

View File

@@ -26,7 +26,7 @@ type TunAdapter interface {
// IFaceDiscover export internal IFaceDiscover for mobile
type IFaceDiscover interface {
stdnet.ExternalIFaceDiscover
stdnet.IFaceDiscover
}
func init() {

View File

@@ -193,7 +193,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe`
Sleep 3000
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
RmDir /r "$INSTDIR"
SetShellVarContext current

View File

@@ -27,7 +27,7 @@ const (
)
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
"Tailscale", "tailscale", "docker", "veth", "br-"}
// ConfigInput carries configuration changes to the client
type ConfigInput struct {

View File

@@ -23,7 +23,7 @@ import (
)
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) error {
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -108,7 +108,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: iface.WireGuardModuleIsLoaded(),
KernelInterface: iface.WireguardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
@@ -144,19 +144,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover)
if err != nil {
log.Error(err)
return wrapErr(err)
}
md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
if err != nil {
log.Error(err)
return wrapErr(err)
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, statusRecorder)
err = engine.Start()
if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
@@ -200,10 +194,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*EngineConfig, error) {
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,

View File

@@ -9,10 +9,7 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
@@ -202,11 +199,7 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
@@ -46,6 +47,10 @@ var ErrResetConnection = fmt.Errorf("reset connection")
type EngineConfig struct {
WgPort int
WgIfaceName string
// TunAdapter is option. It is necessary for mobile version.
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.IFaceDiscover
// WgAddr is a Wireguard local address (Netbird Network IP)
WgAddr string
@@ -85,9 +90,7 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex
config *EngineConfig
mobileDep MobileDependency
config *EngineConfig
// STUNs is a list of STUN servers used by ICE
STUNs []*ice.URL
// TURNs is a list of STUN servers used by ICE
@@ -127,7 +130,7 @@ type Peer struct {
func NewEngine(
ctx context.Context, cancel context.CancelFunc,
signalClient signal.Client, mgmClient mgm.Client,
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
config *EngineConfig, statusRecorder *peer.Status,
) *Engine {
return &Engine{
ctx: ctx,
@@ -137,7 +140,6 @@ func NewEngine(
peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*ice.URL{},
TURNs: []*ice.URL{},
networkSerial: 0,
@@ -164,77 +166,68 @@ func (e *Engine) Stop() error {
return nil
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Start creates a new Wireguard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
wgIFaceName := e.config.WgIfaceName
wgIfaceName := e.config.WgIfaceName
wgAddr := e.config.WgAddr
myPrivateKey := e.config.WgPrivateKey
var err error
transportNet, err := e.newStdNet()
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
return err
}
networkName := "udp"
if e.config.DisableIPv6Discovery {
networkName = "udp4"
}
transportNet, err := e.newStdNet()
if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err)
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
e.close()
return err
}
udpMuxParams := ice.UDPMuxParams{
UDPConn: e.udpMuxConn,
Net: transportNet,
}
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
e.close()
return err
}
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
err = e.wgInterface.Create()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
e.close()
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
e.close()
return err
}
if e.wgInterface.IsUserspaceBind() {
iceBind := e.wgInterface.GetBind()
udpMux, err := iceBind.GetICEMux()
if err != nil {
e.close()
return err
}
e.udpMux = udpMux.UDPMuxDefault
e.udpMuxSrflx = udpMux
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
} else {
networkName := "udp"
if e.config.DisableIPv6Discovery {
networkName = "udp4"
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
e.close()
return err
}
udpMuxParams := ice.UDPMuxParams{
UDPConn: e.udpMuxConn,
Net: transportNet,
}
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
e.close()
return err
}
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
}
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
if e.dnsServer == nil {
@@ -503,7 +496,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(),
KernelInterface: iface.WireGuardModuleIsLoaded(),
KernelInterface: iface.WireguardModuleIsLoaded(),
FQDN: conf.GetFqdn(),
})
@@ -829,10 +822,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
}
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
if err != nil {
return nil, err
}
@@ -1014,6 +1006,12 @@ func (e *Engine) close() {
}
}
if e.udpMuxSrflx != nil {
if err := e.udpMuxSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux: %v", err)
}
}
if e.udpMuxConn != nil {
if err := e.udpMuxConn.Close(); err != nil {
log.Debugf("close udp mux connection: %v", err)

View File

@@ -3,9 +3,9 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/pion/transport/v2/stdnet"
)
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.config.IFaceBlackList)
return stdnet.NewNet()
}

View File

@@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
return stdnet.NewNet(e.config.IFaceDiscover)
}

View File

@@ -3,8 +3,6 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2/stdnet"
"net"
"net/netip"
"os"
@@ -74,7 +72,7 @@ func TestEngine_SSH(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -208,24 +206,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
if err != nil {
t.Fatal(err)
}
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
conn, err := net.ListenUDP("udp4", nil)
if err != nil {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
type testCase struct {
name string
@@ -404,7 +390,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -562,12 +548,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
@@ -731,12 +713,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
@@ -1000,7 +978,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil
}
func startSignal() (*grpc.Server, string, error) {

View File

@@ -1,13 +0,0 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
)
// MobileDependency collect all dependencies for mobile platform
type MobileDependency struct {
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
Routes []string
}

View File

@@ -1,29 +0,0 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: ifaceDiscover,
}
err := md.readMap(mgmClient)
return md, err
}
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
routes, err := mgmClient.GetRoutes()
if err != nil {
return err
}
d.Routes = make([]string, len(routes))
for i, r := range routes {
d.Routes[i] = r.GetNetwork()
}
return nil
}

View File

@@ -1,13 +0,0 @@
//go:build !android
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
return MobileDependency{}, nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/pion/ice/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -45,9 +46,6 @@ type ConnConfig struct {
LocalWgPort int
NATExternalIPs []string
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
}
// OfferAnswer represents a session establishment offer or answer
@@ -97,7 +95,7 @@ type Conn struct {
meta meta
adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover
iFaceDiscover stdnet.IFaceDiscover
}
// meta holds meta information about a connection
@@ -123,7 +121,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*Conn, error) {
return &Conn{
config: config,
mu: sync.Mutex{},
@@ -138,6 +136,32 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
}, nil
}
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them
func interfaceFilter(blackList []string) func(string) bool {
return func(iFace string) bool {
for _, s := range blackList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}
func (conn *Conn) reCreateAgent() error {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -147,7 +171,7 @@ func (conn *Conn) reCreateAgent() error {
var err error
transportNet, err := conn.newStdNet()
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
log.Warnf("failed to create pion's stdnet: %s", err)
}
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
@@ -155,7 +179,7 @@ func (conn *Conn) reCreateAgent() error {
Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
@@ -295,7 +319,7 @@ func (conn *Conn) Open() error {
return err
}
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
if conn.proxy.Type() == proxy.TypeNoProxy {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
@@ -317,62 +341,29 @@ func (conn *Conn) Open() error {
// useProxy determines whether a direct connection (without a go proxy) is possible
//
// There are 3 cases:
// There are 2 cases:
//
// * When neither candidate is from hard nat and one of the peers has a public IP
//
// * both peers are in the same private network
//
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
//
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
if !isRelayCandidate(pair.Local) && userspaceBind {
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
return false
}
func shouldUseProxy(pair *ice.CandidatePair) bool {
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
return false
}
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
return false
}
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
return false
}
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) {
return false
}
return true
}
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
localIP := net.ParseIP(pair.Local.Address())
remoteIP := net.ParseIP(pair.Remote.Address())
if localIP == nil || remoteIP == nil {
return false
}
// only consider /16 networks
mask := net.IPMask{255, 255, 0, 0}
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isHardNATCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
}
@@ -385,13 +376,9 @@ func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
}
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
}
func isPublicIP(address string) bool {
ip := net.ParseIP(address)
if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
return false
}
return true
@@ -425,7 +412,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
peerState.Direct = p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -436,7 +423,8 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
}
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
useProxy := shouldUseProxy(pair)
localDirectMode := !useProxy
remoteDirectMode := localDirectMode
@@ -446,16 +434,13 @@ func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgP
remoteDirectMode = conn.receiveRemoteDirectMode()
}
if conn.config.UserspaceBind && localDirectMode {
return proxy.NewNoProxy(conn.config.ProxyConfig)
}
if localDirectMode && remoteDirectMode {
return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
log.Debugf("using WireGuard direct mode with peer %s", conn.config.Key)
return proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
}
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
return proxy.NewWireguardProxy(conn.config.ProxyConfig)
}
func (conn *Conn) sendLocalDirectMode(localMode bool) {

View File

@@ -5,8 +5,6 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/magiconair/properties/assert"
"github.com/pion/ice/v2"
"golang.org/x/sync/errgroup"
@@ -30,7 +28,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"}
filter := stdnet.InterfaceFilter(ignore)
filter := interfaceFilter(ignore)
for _, s := range ignore {
assert.Equal(t, filter(s), false)
@@ -210,7 +208,6 @@ func TestConn_ShouldUseProxy(t *testing.T) {
return ice.CandidateTypeHost
},
}
srflxCandidate := &mockICECandidate{
AddressFunc: func() string {
return "1.1.1.1"
@@ -323,47 +320,11 @@ func TestConn_ShouldUseProxy(t *testing.T) {
},
expected: false,
},
{
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: false,
},
{
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := shouldUseProxy(testCase.candatePair, false)
result := shouldUseProxy(testCase.candatePair)
if result != testCase.expected {
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
}
@@ -404,7 +365,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: true,
inputRemoteModeMessage: true,
expected: proxy.TypeWireGuard,
expected: proxy.TypeWireguard,
},
{
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
@@ -414,7 +375,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: true,
inputRemoteModeMessage: false,
expected: proxy.TypeWireGuard,
expected: proxy.TypeWireguard,
},
{
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
@@ -424,7 +385,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: false,
inputRemoteModeMessage: false,
expected: proxy.TypeWireGuard,
expected: proxy.TypeWireguard,
},
{
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
@@ -434,7 +395,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: false,
inputRemoteModeMessage: false,
expected: proxy.TypeDirectNoProxy,
expected: proxy.TypeNoProxy,
},
{
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
@@ -444,7 +405,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: true,
inputRemoteModeMessage: true,
expected: proxy.TypeDirectNoProxy,
expected: proxy.TypeNoProxy,
},
}
for _, testCase := range testCases {

View File

@@ -3,9 +3,9 @@
package peer
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/pion/transport/v2/stdnet"
)
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(conn.config.InterfaceBlackList)
return stdnet.NewNet()
}

View File

@@ -3,5 +3,5 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
return stdnet.NewNet(conn.iFaceDiscover)
}

View File

@@ -1,57 +0,0 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
// This is possible in either of these cases:
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
type DirectNoProxy struct {
config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
}
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
}
// Close removes peer from the WireGuard interface
func (p *DirectNoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote IP and default WireGuard port
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
// Type returns the type of this proxy
func (p *DirectNoProxy) Type() Type {
return TypeDirectNoProxy
}

View File

@@ -5,18 +5,24 @@ import (
"net"
)
// NoProxy is used just to configure WireGuard without any local proxy in between.
// Used when the WireGuard interface is userspace and uses bind.ICEBind
// NoProxy is used when there is no need for a proxy between ICE and Wireguard.
// This is possible in either of these cases:
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// NoProxy will just update remote peer with a remote host and fixed Wireguard port (r.g. 51820).
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
type NoProxy struct {
config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
}
// NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config}
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port
func NewNoProxy(config Config, remoteWgPort int) *NoProxy {
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort}
}
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
@@ -25,16 +31,23 @@ func (p *NoProxy) Close() error {
return nil
}
// Start just updates WireGuard peer with the remote address
// Start just updates Wireguard peer with the remote IP and default Wireguard port
func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
log.Debugf("using NoProxy while connecting to peer %s", p.config.RemoteKey)
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
func (p *NoProxy) Type() Type {

View File

@@ -13,10 +13,9 @@ const DefaultWgKeepAlive = 25 * time.Second
type Type string
const (
TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
TypeNoProxy Type = "NoProxy"
TypeWireguard Type = "Wireguard"
TypeDummy Type = "Dummy"
)
type Config struct {

View File

@@ -6,8 +6,8 @@ import (
"net"
)
// WireGuardProxy proxies
type WireGuardProxy struct {
// WireguardProxy proxies
type WireguardProxy struct {
ctx context.Context
cancel context.CancelFunc
@@ -17,13 +17,13 @@ type WireGuardProxy struct {
localConn net.Conn
}
func NewWireGuardProxy(config Config) *WireGuardProxy {
p := &WireGuardProxy{config: config}
func NewWireguardProxy(config Config) *WireguardProxy {
p := &WireguardProxy{config: config}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *WireGuardProxy) updateEndpoint() error {
func (p *WireguardProxy) updateEndpoint() error {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
if err != nil {
return err
@@ -38,7 +38,7 @@ func (p *WireGuardProxy) updateEndpoint() error {
return nil
}
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
func (p *WireguardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn
var err error
@@ -60,7 +60,7 @@ func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
return nil
}
func (p *WireGuardProxy) Close() error {
func (p *WireguardProxy) Close() error {
p.cancel()
if c := p.localConn; c != nil {
err := p.localConn.Close()
@@ -77,7 +77,7 @@ func (p *WireGuardProxy) Close() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WireGuardProxy) proxyToRemote() {
func (p *WireguardProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
@@ -101,7 +101,7 @@ func (p *WireGuardProxy) proxyToRemote() {
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WireGuardProxy) proxyToLocal() {
func (p *WireguardProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
@@ -123,6 +123,6 @@ func (p *WireGuardProxy) proxyToLocal() {
}
}
func (p *WireGuardProxy) Type() Type {
return TypeWireGuard
func (p *WireguardProxy) Type() Type {
return TypeWireguard
}

View File

@@ -1,15 +1,12 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
)
import "github.com/google/nftables"
const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding"

View File

@@ -1,17 +1,14 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"net/netip"
"os/exec"
"strings"
"sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
)
func isIptablesSupported() bool {

View File

@@ -1,13 +1,10 @@
//go:build !android
package routemanager
import (
"context"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
"testing"
)
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@@ -1,130 +1,9 @@
package routemanager
import (
"context"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
import "github.com/netbirdio/netbird/route"
// Manager is a route manager interface
type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
Stop()
}
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRouter: newServerRouter(ctx, wgInterface),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.cleanUp()
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}

View File

@@ -0,0 +1,31 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// DefaultManager dummy router manager for Android
type DefaultManager struct {
ctx context.Context
serverRouter *serverRouter
wgInterface *iface.WGIface
}
// NewManager returns a new dummy route manager what doing nothing
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
return &DefaultManager{}
}
// UpdateRoutes ...
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
return nil
}
// Stop ...
func (m *DefaultManager) Stop() {
}

View File

@@ -0,0 +1,186 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRoutes map[string]*route.Route
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRoutes: make(map[string]*route.Route),
serverRouter: &serverRouter{
routes: make(map[string]*route.Route),
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
firewall: NewFirewall(ctx),
},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.firewall.CleanRoutingRules()
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.serverRouter.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.serverRoutes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
continue
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.serverRoutes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.serverRoutes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.serverRoutes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.serverRoutes[id] = newRoute
}
if len(m.serverRoutes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.updateServerRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}

View File

@@ -3,7 +3,6 @@ package routemanager
import (
"context"
"fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip"
"runtime"
"testing"
@@ -392,12 +391,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
@@ -420,7 +414,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes {
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match")
}
})
}

View File

@@ -1,19 +1,16 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"net"
"net/netip"
"sync"
)
import "github.com/google/nftables"
const (
nftablesTable = "netbird-rt"

View File

@@ -1,15 +1,12 @@
//go:build !android
package routemanager
import (
"context"
"testing"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@@ -1,24 +0,0 @@
package routemanager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}

View File

@@ -0,0 +1,67 @@
package routemanager
import (
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"net/netip"
"sync"
)
type serverRouter struct {
routes map[string]*route.Route
// best effort to keep net forward configuration as it was
netForwardHistoryEnabled bool
mux sync.Mutex
firewall firewallManager
}
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.serverRouter.routes, route.ID)
return nil
}
}
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.serverRouter.routes[route.ID] = route
return nil
}
}

View File

@@ -1,21 +0,0 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{}
}
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (r *serverRouter) cleanUp() {}

View File

@@ -1,120 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewallManager
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: NewFirewall(ctx),
wgInterface: wgInterface,
}
}
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.routes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.routes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.routes[id] = newRoute
}
if len(m.routes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.routes, route.ID)
return nil
}
}
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.routes[route.ID] = route
return nil
}
}
func (m *serverRouter) cleanUp() {
m.firewall.CleanRoutingRules()
}

View File

@@ -1,14 +1,11 @@
//go:build !android
package routemanager
import (
"fmt"
"net"
"net/netip"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"net"
"net/netip"
)
var errRouteNotFound = fmt.Errorf("route not found")

View File

@@ -1,13 +0,0 @@
package routemanager
import (
"net/netip"
)
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return nil
}

View File

@@ -1,13 +1,10 @@
//go:build !android
package routemanager
import (
"github.com/vishvananda/netlink"
"net"
"net/netip"
"os"
"github.com/vishvananda/netlink"
)
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
@@ -65,3 +62,12 @@ func enableIPForwarding() error {
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
return err
}
func isNetForwardHistoryEnabled() bool {
out, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
// todo
panic(err)
}
return string(out) == "1"
}

View File

@@ -4,11 +4,10 @@
package routemanager
import (
log "github.com/sirupsen/logrus"
"net/netip"
"os/exec"
"runtime"
log "github.com/sirupsen/logrus"
)
func addToRouteTable(prefix netip.Prefix, addr string) error {
@@ -35,3 +34,8 @@ func enableIPForwarding() error {
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func isNetForwardHistoryEnabled() bool {
log.Infof("check netforward history is not implemented on %s", runtime.GOOS)
return false
}

View File

@@ -3,7 +3,6 @@ package routemanager
import (
"fmt"
"github.com/netbirdio/netbird/iface"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require"
"net"
"net/netip"
@@ -33,11 +32,7 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

View File

@@ -1,14 +0,0 @@
package stdnet
import "github.com/pion/transport/v2"
// ExternalIFaceDiscover provide an option for external services (mobile)
// to collect network interface information
type ExternalIFaceDiscover interface {
// IFaces return with the description of the interfaces
IFaces() (string, error)
}
type iFaceDiscover interface {
iFaces() ([]*transport.Interface, error)
}

View File

@@ -1,98 +0,0 @@
package stdnet
import (
"fmt"
"net"
"strings"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
)
type mobileIFaceDiscover struct {
externalDiscover ExternalIFaceDiscover
}
func newMobileIFaceDiscover(externalDiscover ExternalIFaceDiscover) *mobileIFaceDiscover {
return &mobileIFaceDiscover{
externalDiscover: externalDiscover,
}
}
func (m *mobileIFaceDiscover) iFaces() ([]*transport.Interface, error) {
ifacesString, err := m.externalDiscover.IFaces()
if err != nil {
return nil, err
}
interfaces := m.parseInterfacesString(ifacesString)
return interfaces, nil
}
func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transport.Interface {
ifs := []*transport.Interface{}
for _, iface := range strings.Split(interfaces, "\n") {
if strings.TrimSpace(iface) == "" {
continue
}
fields := strings.Split(iface, "|")
if len(fields) != 2 {
log.Warnf("parseInterfacesString: unable to split %q", iface)
continue
}
var name string
var index, mtu int
var up, broadcast, loopback, pointToPoint, multicast bool
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
if err != nil {
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
continue
}
newIf := net.Interface{
Name: name,
Index: index,
MTU: mtu,
}
if up {
newIf.Flags |= net.FlagUp
}
if broadcast {
newIf.Flags |= net.FlagBroadcast
}
if loopback {
newIf.Flags |= net.FlagLoopback
}
if pointToPoint {
newIf.Flags |= net.FlagPointToPoint
}
if multicast {
newIf.Flags |= net.FlagMulticast
}
ifc := transport.NewInterface(newIf)
addrs := strings.Trim(fields[1], " \n")
foundAddress := false
for _, addr := range strings.Split(addrs, " ") {
if strings.Contains(addr, "%") {
continue
}
ip, ipNet, err := net.ParseCIDR(addr)
if err != nil {
log.Warnf("%s", err)
continue
}
ipNet.IP = ip
ifc.AddAddress(ipNet)
foundAddress = true
}
if foundAddress {
ifs = append(ifs, ifc)
}
}
return ifs
}

View File

@@ -1,36 +0,0 @@
package stdnet
import (
"net"
"github.com/pion/transport/v2"
)
type pionDiscover struct {
}
func (d pionDiscover) iFaces() ([]*transport.Interface, error) {
ifs := []*transport.Interface{}
oifs, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, oif := range oifs {
ifc := transport.NewInterface(oif)
addrs, err := oif.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ifc.AddAddress(addr)
}
ifs = append(ifs, ifc)
}
return ifs, nil
}

View File

@@ -1,40 +0,0 @@
package stdnet
import (
"strings"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
)
// InterfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them.
func InterfaceFilter(disallowList []string) func(string) bool {
return func(iFace string) bool {
if strings.HasPrefix(iFace, "lo") {
// hardcoded loopback check to support already installed agents
return false
}
for _, s := range disallowList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}

View File

@@ -0,0 +1,8 @@
package stdnet
// IFaceDiscover provide an option for external services (mobile)
// to collect network interface information
type IFaceDiscover interface {
// IFaces return with the description of the interfaces
IFaces() (string, error)
}

View File

@@ -5,50 +5,37 @@ package stdnet
import (
"fmt"
"net"
"strings"
"github.com/pion/transport/v2"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
)
// Net is an implementation of the net.Net interface
// based on functions of the standard net package.
type Net struct {
stdnet.Net
interfaces []*transport.Interface
iFaceDiscover iFaceDiscover
// interfaceFilter should return true if the given interfaceName is allowed
interfaceFilter func(interfaceName string) bool
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
n := &Net{
iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover),
interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
interfaces []*transport.Interface
}
// NewNet creates a new StdNet instance.
func NewNet(disallowList []string) (*Net, error) {
n := &Net{
iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
func NewNet(iFaceDiscover IFaceDiscover) (*Net, error) {
n := &Net{}
return n, n.UpdateInterfaces(iFaceDiscover)
}
// UpdateInterfaces updates the internal list of network interfaces
// and associated addresses filtering them by name.
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
// wasn't specified.
func (n *Net) UpdateInterfaces() (err error) {
allIfaces, err := n.iFaceDiscover.iFaces()
// and associated addresses.
func (n *Net) UpdateInterfaces(iFaceDiscover IFaceDiscover) error {
ifacesString, err := iFaceDiscover.IFaces()
if err != nil {
return err
}
n.interfaces = n.filterInterfaces(allIfaces)
return nil
n.interfaces = parseInterfacesString(ifacesString)
return err
}
// Interfaces returns a slice of interfaces which are available on the
@@ -83,15 +70,68 @@ func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
}
func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.Interface {
if n.interfaceFilter == nil {
return interfaces
}
result := []*transport.Interface{}
for _, iface := range interfaces {
if n.interfaceFilter(iface.Name) {
result = append(result, iface)
func parseInterfacesString(interfaces string) []*transport.Interface {
ifs := []*transport.Interface{}
for _, iface := range strings.Split(interfaces, "\n") {
if strings.TrimSpace(iface) == "" {
continue
}
fields := strings.Split(iface, "|")
if len(fields) != 2 {
log.Warnf("parseInterfacesString: unable to split %q", iface)
continue
}
var name string
var index, mtu int
var up, broadcast, loopback, pointToPoint, multicast bool
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
if err != nil {
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
continue
}
newIf := net.Interface{
Name: name,
Index: index,
MTU: mtu,
}
if up {
newIf.Flags |= net.FlagUp
}
if broadcast {
newIf.Flags |= net.FlagBroadcast
}
if loopback {
newIf.Flags |= net.FlagLoopback
}
if pointToPoint {
newIf.Flags |= net.FlagPointToPoint
}
if multicast {
newIf.Flags |= net.FlagMulticast
}
ifc := transport.NewInterface(newIf)
addrs := strings.Trim(fields[1], " \n")
foundAddress := false
for _, addr := range strings.Split(addrs, " ") {
ip, ipNet, err := net.ParseCIDR(addr)
if err != nil {
log.Warnf("%s", err)
continue
}
ipNet.IP = ip
ifc.AddAddress(ipNet)
foundAddress = true
}
if foundAddress {
ifs = append(ifs, ifc)
}
}
return result
return ifs
}

View File

@@ -3,8 +3,6 @@ package stdnet
import (
"fmt"
"testing"
log "github.com/sirupsen/logrus"
)
func Test_parseInterfacesString(t *testing.T) {
@@ -22,7 +20,6 @@ func Test_parseInterfacesString(t *testing.T) {
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
{"rmnet_data2", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2%rmnet2/64"},
}
var exampleString string
@@ -38,13 +35,11 @@ func Test_parseInterfacesString(t *testing.T) {
d.multicast,
d.addr)
}
d := mobileIFaceDiscover{}
nets := d.parseInterfacesString(exampleString)
nets := parseInterfacesString(exampleString)
if len(nets) == 0 {
t.Fatalf("failed to parse interfaces")
}
log.Printf("%d", len(nets))
for i, net := range nets {
if net.MTU != testData[i].mtu {
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
@@ -63,7 +58,7 @@ func Test_parseInterfacesString(t *testing.T) {
if len(addr) == 0 {
t.Errorf("invalid address parsing")
}
log.Printf("%v", addr)
if addr[0].String() != testData[i].addr {
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
}

View File

@@ -6,4 +6,4 @@
#define EXPAND(x) STRINGIZE(x)
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
7 ICON ui/netbird.ico
wintun.dll RCDATA wintun.dll
wireguard.dll RCDATA wireguard.dll

27
client/wireguard_nt.sh Normal file
View File

@@ -0,0 +1,27 @@
#!/bin/bash
ldir=$PWD
tmp_dir_path=$ldir/.distfiles
winnt=wireguard-nt.zip
download_file_path=$tmp_dir_path/$winnt
download_url=https://download.wireguard.com/wireguard-nt/wireguard-nt-0.10.1.zip
download_sha=772c0b1463d8d2212716f43f06f4594d880dea4f735165bd68e388fc41b81605
function resources_windows(){
cmd=$1
arch=$2
out=$3
docker run -i --rm -v $PWD:$PWD -w $PWD mstorsjo/llvm-mingw:latest $cmd -O coff -c 65001 -I $tmp_dir_path/wireguard-nt/bin/$arch -i resources.rc -o $out
}
mkdir -p $tmp_dir_path
curl -L#o $download_file_path.unverified $download_url
echo "$download_sha $download_file_path.unverified" | sha256sum -c
mv $download_file_path.unverified $download_file_path
mkdir -p .deps
unzip $download_file_path -d $tmp_dir_path
resources_windows i686-w64-mingw32-windres x86 resources_windows_386.syso
resources_windows aarch64-w64-mingw32-windres arm64 resources_windows_arm64.syso
resources_windows x86_64-w64-mingw32-windres amd64 resources_windows_amd64.syso

7
go.mod
View File

@@ -19,7 +19,7 @@ require (
github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.7.0
golang.org/x/sys v0.6.0
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.52.3
@@ -48,8 +48,6 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/open-policy-agent/opa v0.49.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
github.com/pion/stun v0.4.0
github.com/pion/transport/v2 v2.0.2
github.com/prometheus/client_golang v1.14.0
github.com/rs/xid v1.3.0
@@ -105,8 +103,10 @@ require (
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.6 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/stun v0.4.0 // indirect
github.com/pion/turn/v2 v2.1.0 // indirect
github.com/pion/udp/v2 v2.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -131,6 +131,7 @@ require (
golang.org/x/mod v0.8.0 // indirect
golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.6.0 // indirect
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d // indirect
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect

7
go.sum
View File

@@ -881,14 +881,13 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202 h1:KcD4X7IcoRdQpr9NSQpQpn5S4rUMIQaCdF90FOEj6KY=
golang.zx2c4.com/wireguard v0.0.0-20230310135217-9e2f38602202/go.mod h1:qc3aHNhM1Rc4hW2az896MjLVcxHvLbJ6LZc9MI7RTMY=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 h1:3zl8RkJNQ8wfPRomwv/6DBbH2Ut6dgMaWTxM0ZunWnE=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=

View File

@@ -1,425 +0,0 @@
package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"syscall"
"github.com/pion/stun"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
var (
_ wgConn.Bind = (*ICEBind)(nil)
)
// ICEBind implements Bind for all platforms except Windows.
type ICEBind struct {
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
batchSize int
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
// NetBird related variables
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
worker *worker
}
func NewICEBind(transportNet transport.Net) *ICEBind {
b := &ICEBind{
batchSize: wgConn.DefaultBatchSize,
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, wgConn.DefaultBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, wgConn.DefaultBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
transportNet: transportNet,
}
b.worker = newWorker(b.handlePkgs)
return b
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if supported.
src struct {
netip.Addr
ifidx int32
}
}
var (
_ wgConn.Bind = (*ICEBind)(nil)
_ wgConn.Endpoint = &StdNetEndpoint{}
)
func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
e, err := netip.ParseAddrPort(s)
return asEndpoint(e), err
}
func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0
e.src.Addr = netip.Addr{}
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return e.src.Addr
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return e.src.ifidx
}
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func (e *StdNetEndpoint) SrcToString() string {
return e.src.Addr.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
var err error
var tries int
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, wgConn.ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
return nil, 0, err
}
var fns []wgConn.ReceiveFunc
if v4conn != nil {
fns = append(fns, s.receiveIPv4)
s.ipv4 = v4conn
}
if v6conn != nil {
fns = append(fns, s.receiveIPv6)
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
s.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: s.ipv4, Net: s.transportNet})
return fns, uint16(port), nil
}
func (s *ICEBind) receiveIPv4(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
s.worker.doWork((*msgs)[:numMsgs], sizes, eps)
return numMsgs, nil
}
func (s *ICEBind) receiveIPv6(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv6MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *ICEBind) BatchSize() int {
return s.batchSize
}
func (s *ICEBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
}
s.blackhole4 = false
s.blackhole6 = false
if err1 != nil {
return err1
}
return err2
}
func (s *ICEBind) Send(buffs [][]byte, endpoint wgConn.Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
is6 = true
}
s.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
if is6 {
return s.send6(s.ipv6PC, endpoint, buffs)
} else {
return s.send4(s.ipv4PC, endpoint, buffs)
}
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for _, buffer := range buffers {
if !stun.IsMessage(buffer) {
continue
}
msg, err := parseSTUNMessage(buffer[:n])
if err != nil {
buffer = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle packet")
}
buffer = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) handlePkgs(msg *ipv4.Message) (int, *StdNetEndpoint) {
// todo: handle err
size := 0
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if !ok {
size = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
return size, ep
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]*StdNetEndpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e
}
return e
}
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}

View File

@@ -1,36 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"net"
"syscall"
)
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
}

View File

@@ -1,41 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"fmt"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
case "udp6":
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

View File

@@ -1,28 +0,0 @@
//go:build !windows && !linux && !js
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
var err error
if network == "udp6" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
}
return err
},
)
}

View File

@@ -1,12 +0,0 @@
//go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
func (s *ICEBind) SetMark(mark uint32) error {
return nil
}

View File

@@ -1,65 +0,0 @@
//go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"runtime"
"golang.org/x/sys/unix"
)
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (s *ICEBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
if s.ipv4 != nil {
fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
if s.ipv6 != nil {
fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,28 +0,0 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
// use alternatively named flags and need ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *wgConn.StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *wgConn.StdNetEndpoint) {
}
// srcControlSize returns the recommended buffer size for pooling sticky control
// data.
const srcControlSize = 0

View File

@@ -1,111 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
ep.src.ifidx = info.Ifindex
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
ep.src.Addr = netip.AddrFrom16(info.Addr)
ep.src.ifidx = int32(info.Ifindex)
return
}
}
}
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
// panics if buf is of insufficient size.
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
size := int(unsafe.Sizeof(t))
if len(buf) < size {
panic("pktInfoFromBuf: buffer too small")
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
return t
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = (*control)[:cap(*control)]
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
*control = (*control)[:0]
return
}
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
*control = (*control)[:0]
return
}
if len(*control) < srcControlSize {
*control = (*control)[:0]
return
}
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
if ep.SrcIP().Is4() {
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = ep.src.ifidx
if ep.SrcIP().IsValid() {
info.Spec_dst = ep.SrcIP().As4()
}
} else {
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = uint32(ep.src.ifidx)
if ep.SrcIP().IsValid() {
info.Addr = ep.SrcIP().As16()
}
}
*control = (*control)[:hdr.Len]
}
var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)

View File

@@ -1,207 +0,0 @@
//go:build linux
// +build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bind
import (
"context"
"net"
"net/netip"
"runtime"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
ep.src.Addr = netip.MustParseAddr("127.0.0.1")
ep.src.ifidx = 5
control := make([]byte, srcControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IP {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IP_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
t.Errorf("unexpected address: %v", info.Spec_dst)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("IPv6", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
ep.src.Addr = netip.MustParseAddr("::1")
ep.src.ifidx = 5
control := make([]byte, srcControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IPV6 {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IPV6_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Addr != ep.SrcIP().As16() {
t.Errorf("unexpected address: %v", info.Addr)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
hdr.Len = 3
setSrcControl(&control, &StdNetEndpoint{})
if len(control) != 0 {
t.Errorf("unexpected control: %v", control)
}
})
}
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.src.Addr)
}
if ep.src.ifidx != 5 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.src.ifidx != 5 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
control := make([]byte, srcControlSize)
ep := &StdNetEndpoint{}
ep.src.Addr = netip.MustParseAddr("::1")
ep.src.ifidx = 5
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
t.Errorf("unexpected address: %v", ep.src.Addr)
}
if ep.src.ifidx != 0 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
}
func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IP_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IPV6_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
}

View File

@@ -1,445 +0,0 @@
package bind
import (
"fmt"
"io"
"net"
"strings"
"sync"
"github.com/pion/ice/v2"
"github.com/pion/stun"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/transport/v2"
)
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
mu sync.Mutex
// for UDP connection listen at unspecified address
localAddrsForUnspecified []net.Addr
}
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
// Required for gathering local addresses
// in case a un UDPConn is passed which does not
// bind to a specific local address.
Net transport.Net
InterfaceFilter func(interfaceName string) bool
}
func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{}
ifaces, err := n.Interfaces()
if err != nil {
return ips, err
}
var IPv4Requested, IPv6Requested bool
for _, typ := range networkTypes {
if typ.IsIPv4() {
IPv4Requested = true
}
if typ.IsIPv6() {
IPv6Requested = true
}
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface
}
if interfaceFilter != nil && !interfaceFilter(iface.Name) {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch addr := addr.(type) {
case *net.IPNet:
ip = addr.IP
case *net.IPAddr:
ip = addr.IP
}
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue
}
if ipv4 := ip.To4(); ipv4 == nil {
if !IPv6Requested {
continue
} else if !isSupportedIPv6(ip) {
continue
}
} else if !IPv4Requested {
continue
}
if ipFilter != nil && !ipFilter(ip) {
continue
}
ips = append(ips, ip)
}
}
return ips, nil
}
// The conditions of invalidation written below are defined in
// https://tools.ietf.org/html/rfc8445#section-5.1.1.1
func isSupportedIPv6(ip net.IP) bool {
if len(ip) != net.IPv6len ||
isZeros(ip[0:12]) || // !(IPv4-compatible IPv6)
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast)
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() {
return false
}
return true
}
func isZeros(ip net.IP) bool {
for i := 0; i < len(ip); i++ {
if ip[i] != 0 {
return false
}
}
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
}
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified
}
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != ufrag {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.connsIPv4 {
_ = c.Close()
}
for _, c := range m.connsIPv6 {
_ = c.Close()
}
m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)
close(m.closedChan)
_ = m.params.UDPConn.Close()
})
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.Unlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
type bufferHolder struct {
buf []byte
}
func newBufferHolder(size int) *bufferHolder {
return &bufferHolder{
buf: make([]byte, size),
}
}

View File

@@ -1,254 +0,0 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
*/
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/v2"
)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
}
// GetListenAddresses returns the listen addr of this UDP
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
// Not implemented yet.
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
return nil, fmt.Errorf("not implemented yet")
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
// All other STUN packets will be forwarded to the UDPMux
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {
log.Debugf("%s: %v", fmt.Errorf("failed to get XOR-MAPPED-ADDRESS response"), err)
return nil
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
if !ok {
return fmt.Errorf("no XOR address mapping")
}
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err != nil {
return err
}
m.xorMappedMap[stunAddr.String()] = mappedAddr
mappedAddr.SetAddr(&addr)
return nil
}
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// if we already have a mapping for this STUN server (address already received)
// and if it is not too old we return it without making a new request to STUN server
if ok {
if mappedAddr.expired() {
mappedAddr.closeWaiters()
delete(m.xorMappedMap, serverAddr.String())
ok = false
} else if mappedAddr.pending() {
ok = false
}
}
m.mu.Unlock()
if ok {
return mappedAddr.addr, nil
}
// otherwise, make a STUN request to discover the address
// or wait for already sent request to complete
waitAddrReceived, err := m.sendStun(serverAddr)
if err != nil {
return nil, fmt.Errorf("%s: %s", "failed to send STUN packet", err)
}
// block until response was handled by the connWorker routine and XORMappedAddress was updated
select {
case <-waitAddrReceived:
// when channel closed, addr was obtained
m.mu.Lock()
mappedAddr := *m.xorMappedMap[serverAddr.String()]
m.mu.Unlock()
if mappedAddr.addr == nil {
return nil, fmt.Errorf("no XOR address mapping")
}
return mappedAddr.addr, nil
case <-time.After(deadline):
return nil, fmt.Errorf("timeout while waiting for XORMappedAddr")
}
}
// sendStun sends a STUN request via UDP conn.
//
// The returned channel is closed when the STUN response has been received.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
// if record present in the map, we already sent a STUN request,
// just wait when waitAddrReceived will be closed
addrMap, ok := m.xorMappedMap[serverAddr.String()]
if !ok {
addrMap = &xorMapped{
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
waitAddrReceived: make(chan struct{}),
}
m.xorMappedMap[serverAddr.String()] = addrMap
}
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
if err != nil {
return nil, err
}
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
return nil, err
}
return addrMap.waitAddrReceived, nil
}
type xorMapped struct {
addr *stun.XORMappedAddress
waitAddrReceived chan struct{}
expiresAt time.Time
}
func (a *xorMapped) closeWaiters() {
select {
case <-a.waitAddrReceived:
// notify was close, ok, that means we received duplicate response
// just exit
break
default:
// notify tha twe have a new addr
close(a.waitAddrReceived)
}
}
func (a *xorMapped) pending() bool {
return a.addr == nil
}
func (a *xorMapped) expired() bool {
return a.expiresAt.Before(time.Now())
}
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
a.addr = addr
a.closeWaiters()
}

View File

@@ -1,233 +0,0 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v2/packetio"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buf *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buf: packetio.NewBuffer(),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buf.Read(buf.buf)
if err != nil {
return 0, nil, err
}
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
// read data and then address
offset := 2
copy(b, buf.buf[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
offset += 2
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
return 0, nil, err
}
return dataLen, rAddr, nil
}
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, rAddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buf.Close()
close(c.closedChan)
})
return err
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buf) < len(data)+maxAddrSize {
return io.ErrShortBuffer
}
// data len
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
offset := 2
// data
copy(buf.buf[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
if err != nil {
return err
}
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
return err
}
return nil
}
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
ipData, err := addr.IP.MarshalText()
if err != nil {
return 0, err
}
total := 2 + len(ipData) + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
binary.LittleEndian.PutUint16(buf, uint16(len(ipData)))
offset := 2
n := copy(buf[offset:], ipData)
offset += n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2
copy(buf[offset:], addr.Zone)
return total, nil
}
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// basic bounds checking
if ipLen+offset > len(buf) {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)
return &addr, nil
}

View File

@@ -1,56 +0,0 @@
package bind
import (
"runtime"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// todo: add close function
type worker struct {
jobOffer chan int
numOfWorker int
jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)
messages []ipv4.Message
sizes []int
eps []wgConn.Endpoint
}
func newWorker(jobFn func(msg *ipv4.Message) (int, *StdNetEndpoint)) *worker {
w := &worker{
jobOffer: make(chan int),
numOfWorker: runtime.NumCPU(),
jobFn: jobFn,
}
w.populateWorkers()
return w
}
func (w *worker) doWork(messages []ipv4.Message, sizes []int, eps []wgConn.Endpoint) {
w.messages = messages
w.sizes = sizes
w.eps = eps
for i := 0; i < len(messages); i++ {
w.jobOffer <- i
}
}
func (w *worker) populateWorkers() {
for i := 0; i < w.numOfWorker; i++ {
go w.loop()
}
}
func (w *worker) loop() {
for {
select {
case msgPos := <-w.jobOffer:
w.sizes[msgPos], w.eps[msgPos] = w.jobFn(&w.messages[msgPos])
}
}
}

View File

@@ -5,8 +5,6 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -18,20 +16,9 @@ const (
// WGIface represents a interface instance
type WGIface struct {
tun *tunDevice
configurer wGConfigurer
mu sync.Mutex
userspaceBind bool
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind
}
// GetBind returns a userspace implementation of WireGuard Bind interface
func (w *WGIface) GetBind() *bind.ICEBind {
return w.tun.iceBind
tun *tunDevice
configurer wGConfigurer
mu sync.Mutex
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
@@ -39,7 +26,7 @@ func (w *WGIface) GetBind() *bind.ICEBind {
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
log.Debugf("create Wireguard interface %s", w.tun.DeviceName())
return w.tun.Create()
}

View File

@@ -1,28 +1,22 @@
package iface
import (
"sync"
import "sync"
"github.com/pion/transport/v2"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
// NewWGIFace Creates a new Wireguard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
wgIface := &WGIface{
mu: sync.Mutex{},
}
wgAddress, err := parseWGAddress(address)
if err != nil {
return wgIFace, err
return wgIface, err
}
tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet)
wgIFace.tun = tun
tun := newTunDevice(wgAddress, mtu, tunAdapter)
wgIface.tun = tun
wgIFace.configurer = newWGConfigurer(tun)
wgIface.configurer = newWGConfigurer(tun)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
return wgIface, nil
}

View File

@@ -2,26 +2,21 @@
package iface
import (
"sync"
import "sync"
"github.com/pion/transport/v2"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
// NewWGIFace Creates a new Wireguard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
wgIface := &WGIface{
mu: sync.Mutex{},
}
wgAddress, err := parseWGAddress(address)
if err != nil {
return wgIFace, err
return wgIface, err
}
wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIface.tun = newTunDevice(ifaceName, wgAddress, mtu)
wgIFace.configurer = newWGConfigurer(iFaceName)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
wgIface.configurer = newWGConfigurer(ifaceName)
return wgIface, nil
}

View File

@@ -2,15 +2,13 @@ package iface
import (
"fmt"
"net"
"testing"
"time"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"testing"
"time"
)
// keep darwin compability
@@ -34,12 +32,7 @@ func init() {
func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet)
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}
@@ -99,11 +92,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}
@@ -132,11 +121,7 @@ func Test_CreateInterface(t *testing.T) {
func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}
@@ -164,11 +149,7 @@ func Test_Close(t *testing.T) {
func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}
@@ -215,11 +196,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}
@@ -278,11 +255,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}
@@ -331,11 +304,8 @@ func Test_ConnectPeers(t *testing.T) {
peer2Key, _ := wgtypes.GeneratePrivateKey()
keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}
@@ -352,11 +322,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet)
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -3,7 +3,7 @@
package iface
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
return false
}

View File

@@ -7,6 +7,9 @@ import (
"bufio"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"io"
"io/fs"
"math"
@@ -14,10 +17,6 @@ import (
"path/filepath"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
// Holds logic to check existence of kernel modules used by wireguard interfaces
@@ -34,7 +33,6 @@ const (
loading
live
inuse
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
)
type module struct {
@@ -83,15 +81,9 @@ func tunModuleIsLoaded() bool {
return tunLoaded
}
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
if canCreateFakeWireguardInterface() {
return true
}
@@ -104,7 +96,7 @@ func WireGuardModuleIsLoaded() bool {
return loaded
}
func canCreateFakeWireGuardInterface() bool {
func canCreateFakeWireguardInterface() bool {
link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid

View File

@@ -3,14 +3,13 @@ package iface
import (
"bufio"
"bytes"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func TestGetModuleDependencies(t *testing.T) {

View File

@@ -2,6 +2,6 @@ package iface
// TunAdapter is an interface for create tun device from externel service
type TunAdapter interface {
ConfigureInterface(address string, mtu int, routes string) (int, error)
ConfigureInterface(address string, mtu int) (int, error)
UpdateAddr(address string) error
}

View File

@@ -1,43 +1,38 @@
package iface
import (
"strings"
"net"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
)
type tunDevice struct {
address WGAddress
mtu int
routes []string
tunAdapter TunAdapter
fd int
name string
device *device.Device
iceBind *bind.ICEBind
fd int
name string
device *device.Device
uapi net.Listener
}
func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter) *tunDevice {
return &tunDevice{
address: address,
mtu: mtu,
routes: routes,
tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet),
}
}
func (t *tunDevice) Create() error {
var err error
routesString := t.routesToString()
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString)
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
return err
@@ -51,11 +46,35 @@ func (t *tunDevice) Create() error {
t.name = name
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device.DisableSomeRoamingForBrokenMobileSemantics()
log.Debugf("create uapi")
tunSock, err := ipc.UAPIOpen(name)
if err != nil {
return err
}
t.uapi, err = ipc.UAPIListen(name, tunSock)
if err != nil {
tunSock.Close()
unix.Close(t.fd)
return err
}
go func() {
for {
uapiConn, err := t.uapi.Accept()
if err != nil {
return
}
go t.device.IpcHandle(uapiConn)
}
}()
err = t.device.Up()
if err != nil {
tunSock.Close()
t.device.Close()
return err
}
@@ -81,13 +100,13 @@ func (t *tunDevice) UpdateAddr(addr WGAddress) error {
}
func (t *tunDevice) Close() (err error) {
if t.uapi != nil {
err = t.uapi.Close()
}
if t.device != nil {
t.device.Close()
}
return
}
func (t *tunDevice) routesToString() string {
return strings.Join(t.routes, ";")
}

View File

@@ -11,7 +11,7 @@ import (
)
func (c *tunDevice) Create() error {
if WireGuardModuleIsLoaded() {
if WireguardModuleIsLoaded() {
log.Info("using kernel WireGuard")
return c.createWithKernel()
}
@@ -30,7 +30,7 @@ func (c *tunDevice) Create() error {
}
// createWithKernel Creates a new WireGuard interface using kernel WireGuard module.
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.
// Works for Linux and offers much better network performance
func (c *tunDevice) createWithKernel() error {

View File

@@ -6,13 +6,10 @@ import (
"net"
"os"
"github.com/pion/transport/v2"
"golang.zx2c4.com/wireguard/ipc"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
@@ -21,18 +18,13 @@ type tunDevice struct {
address WGAddress
mtu int
netInterface NetInterface
iceBind *bind.ICEBind
uapi net.Listener
close chan struct{}
}
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
return &tunDevice{
name: name,
address: address,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
}
}
@@ -50,38 +42,23 @@ func (c *tunDevice) DeviceName() string {
}
func (c *tunDevice) Close() error {
select {
case c.close <- struct{}{}:
default:
if c.netInterface == nil {
return nil
}
var err1, err2, err3 error
if c.netInterface != nil {
err1 = c.netInterface.Close()
}
if c.uapi != nil {
err2 = c.uapi.Close()
err := c.netInterface.Close()
if err != nil {
return err
}
sockPath := "/var/run/wireguard/" + c.name + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil {
statErr = os.Remove(sockPath)
if statErr != nil {
err3 = statErr
return statErr
}
}
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
return nil
}
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
@@ -92,36 +69,26 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
}
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDev.Up()
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
err = tunDevice.Up()
if err != nil {
_ = tunIface.Close()
return nil, err
return tunIface, err
}
c.uapi, err = c.getUAPI(c.name)
// todo: after this line in case of error close the tunSock
uapi, err := c.getUAPI(c.name)
if err != nil {
_ = tunIface.Close()
return nil, err
return tunIface, err
}
go func() {
for {
select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
uapiConn, uapiErr := uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
go tunDevice.IpcHandle(uapiConn)
}
}()

View File

@@ -4,39 +4,24 @@ import (
"fmt"
"net"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/iface/bind"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/driver"
)
type tunDevice struct {
name string
address WGAddress
netInterface NetInterface
iceBind *bind.ICEBind
mtu int
uapi net.Listener
close chan struct{}
}
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{
name: name,
address: address,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
}
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
return &tunDevice{name: name, address: address}
}
func (c *tunDevice) Create() error {
var err error
c.netInterface, err = c.createWithUserspace()
c.netInterface, err = c.createAdapter()
if err != nil {
return err
}
@@ -44,51 +29,6 @@ func (c *tunDevice) Create() error {
return c.assignAddr()
}
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (c *tunDevice) createWithUserspace() (NetInterface, error) {
tunIface, err := tun.CreateTUN(c.name, c.mtu)
if err != nil {
return nil, err
}
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDev.Up()
if err != nil {
_ = tunIface.Close()
return nil, err
}
c.uapi, err = c.getUAPI(c.name)
if err != nil {
_ = tunIface.Close()
return nil, err
}
go func() {
for {
select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
}
}()
log.Debugln("UAPI listener started")
return tunIface, nil
}
func (c *tunDevice) UpdateAddr(address WGAddress) error {
c.address = address
return c.assignAddr()
@@ -103,33 +43,19 @@ func (c *tunDevice) DeviceName() string {
}
func (c *tunDevice) Close() error {
select {
case c.close <- struct{}{}:
default:
if c.netInterface == nil {
return nil
}
var err1, err2 error
if c.netInterface != nil {
err1 = c.netInterface.Close()
}
if c.uapi != nil {
err2 = c.uapi.Close()
}
if err1 != nil {
return err1
}
return err2
return c.netInterface.Close()
}
func (c *tunDevice) getInterfaceGUIDString() (string, error) {
if c.netInterface == nil {
return "", fmt.Errorf("interface has not been initialized yet")
}
windowsDevice := c.netInterface.(*tun.NativeTun)
luid := winipcfg.LUID(windowsDevice.LUID())
windowsDevice := c.netInterface.(*driver.Adapter)
luid := windowsDevice.LUID()
guid, err := luid.GUID()
if err != nil {
return "", err
@@ -137,15 +63,31 @@ func (c *tunDevice) getInterfaceGUIDString() (string, error) {
return guid.String(), nil
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (c *tunDevice) assignAddr() error {
tunDev := c.netInterface.(*tun.NativeTun)
luid := winipcfg.LUID(tunDev.LUID())
log.Debugf("adding address %s to interface: %s", c.address.IP, c.name)
return luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
func (c *tunDevice) createAdapter() (NetInterface, error) {
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
adapter, err := driver.CreateAdapter(c.name, "WireGuard", &WintunStaticRequestedGUID)
if err != nil {
err = fmt.Errorf("error creating adapter: %w", err)
return nil, err
}
err = adapter.SetAdapterState(driver.AdapterStateUp)
if err != nil {
return adapter, err
}
state, _ := adapter.LUID().GUID()
log.Debugln("device guid: ", state.String())
return adapter, nil
}
// getUAPI returns a Listener
func (c *tunDevice) getUAPI(iface string) (net.Listener, error) {
return ipc.UAPIListen(iface)
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (c *tunDevice) assignAddr() error {
luid := c.netInterface.(*driver.Adapter).LUID()
log.Debugf("adding address %s to interface: %s", c.address.IP, c.name)
err := luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
if err != nil {
return err
}
return nil
}

View File

@@ -172,49 +172,6 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil
}
// GetRoutes return with the routes
func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
}
defer func() {
_ = stream.CloseSend()
}()
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return nil, err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return nil, err
}
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return nil, err
}
if decryptedResp.GetNetworkMap() == nil {
return nil, fmt.Errorf("invalid msg, required network map")
}
return decryptedResp.GetNetworkMap().GetRoutes(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}