mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
9 Commits
proxy_cfg_
...
v0.17.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fec0c682e | ||
|
|
c2e90a2a97 | ||
|
|
118880b6f7 | ||
|
|
bb147c2a7c | ||
|
|
4616bc5258 | ||
|
|
1803cf3678 | ||
|
|
9f35a7fb8d | ||
|
|
2eeed55c18 | ||
|
|
0343c5f239 |
4
.github/workflows/golang-test-darwin.yml
vendored
4
.github/workflows/golang-test-darwin.yml
vendored
@@ -6,6 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: macos-latest
|
||||
|
||||
6
.github/workflows/golang-test-linux.yml
vendored
6
.github/workflows/golang-test-linux.yml
vendored
@@ -6,6 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
@@ -66,7 +70,7 @@ jobs:
|
||||
run: go mod tidy
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: go test -c -o iface-testing.bin ./iface/...
|
||||
run: go test -c -o iface-testing.bin ./iface/
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||
|
||||
60
.github/workflows/golang-test-windows.yml
vendored
60
.github/workflows/golang-test-windows.yml
vendored
@@ -6,47 +6,45 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
pre:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- run: bash -x wireguard_nt.sh
|
||||
working-directory: client
|
||||
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: syso
|
||||
path: client/*.syso
|
||||
retention-days: 1
|
||||
|
||||
test:
|
||||
needs: pre
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
id: go
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
|
||||
- uses: actions/cache@v2
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
path: |
|
||||
%LocalAppData%\go-build
|
||||
~\go\pkg\mod
|
||||
~\AppData\Local\go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
|
||||
- uses: actions/download-artifact@v2
|
||||
with:
|
||||
name: syso
|
||||
path: iface\
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Test
|
||||
run: go test -tags=load_wgnt_from_rsrc -timeout 5m -p 1 ./...
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
- run: choco install -y sysinternals
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
3
.github/workflows/golangci-lint.yml
vendored
3
.github/workflows/golangci-lint.yml
vendored
@@ -1,5 +1,8 @@
|
||||
name: golangci-lint
|
||||
on: [pull_request]
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
golangci:
|
||||
name: lint
|
||||
|
||||
4
.github/workflows/install-test-darwin.yml
vendored
4
.github/workflows/install-test-darwin.yml
vendored
@@ -7,7 +7,9 @@ on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: macos-latest
|
||||
|
||||
4
.github/workflows/install-test-linux.yml
vendored
4
.github/workflows/install-test-linux.yml
vendored
@@ -7,7 +7,9 @@ on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
21
.github/workflows/release.yml
vendored
21
.github/workflows/release.yml
vendored
@@ -9,9 +9,13 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.5"
|
||||
SIGN_PIPE_VER: "v0.0.6"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -21,10 +25,6 @@ jobs:
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Generate syso with DLL
|
||||
run: bash -x wireguard_nt.sh
|
||||
working-directory: client
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
@@ -59,6 +59,17 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc amd64
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows rsrc arm64
|
||||
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows rsrc arm
|
||||
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
|
||||
- name: Generate windows rsrc 386
|
||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
|
||||
@@ -6,6 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -26,7 +26,7 @@ type TunAdapter interface {
|
||||
|
||||
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||
type IFaceDiscover interface {
|
||||
stdnet.IFaceDiscover
|
||||
stdnet.ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -32,6 +33,11 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
if hostName != "" {
|
||||
// nolint
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if logFile == "console" {
|
||||
err = handleRebrand(cmd)
|
||||
|
||||
@@ -45,6 +45,7 @@ var (
|
||||
managementURL string
|
||||
adminURL string
|
||||
setupKey string
|
||||
hostName string
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
customDNSAddress string
|
||||
@@ -94,6 +95,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -55,6 +56,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
if hostName != "" {
|
||||
// nolint
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
if foregroundMode {
|
||||
return runInForegroundMode(ctx, cmd)
|
||||
}
|
||||
|
||||
@@ -193,6 +193,7 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
||||
Sleep 3000
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
SetShellVarContext current
|
||||
|
||||
@@ -27,7 +27,7 @@ const (
|
||||
)
|
||||
|
||||
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-"}
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) error {
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -108,7 +108,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
|
||||
@@ -144,13 +144,19 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, statusRecorder)
|
||||
md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
@@ -194,13 +200,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*EngineConfig, error) {
|
||||
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
|
||||
@@ -9,7 +9,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
@@ -199,7 +202,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -47,10 +46,6 @@ var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
// TunAdapter is option. It is necessary for mobile version.
|
||||
TunAdapter iface.TunAdapter
|
||||
|
||||
IFaceDiscover stdnet.IFaceDiscover
|
||||
|
||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||
WgAddr string
|
||||
@@ -90,7 +85,9 @@ type Engine struct {
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
config *EngineConfig
|
||||
mobileDep MobileDependency
|
||||
|
||||
// STUNs is a list of STUN servers used by ICE
|
||||
STUNs []*ice.URL
|
||||
// TURNs is a list of STUN servers used by ICE
|
||||
@@ -130,7 +127,7 @@ type Peer struct {
|
||||
func NewEngine(
|
||||
ctx context.Context, cancel context.CancelFunc,
|
||||
signalClient signal.Client, mgmClient mgm.Client,
|
||||
config *EngineConfig, statusRecorder *peer.Status,
|
||||
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
ctx: ctx,
|
||||
@@ -140,6 +137,7 @@ func NewEngine(
|
||||
peerConns: make(map[string]*peer.Conn),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*ice.URL{},
|
||||
TURNs: []*ice.URL{},
|
||||
networkSerial: 0,
|
||||
@@ -166,68 +164,77 @@ func (e *Engine) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start creates a new Wireguard tunnel interface and listens to events from Signal and Management services
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
wgIfaceName := e.config.WgIfaceName
|
||||
wgIFaceName := e.config.WgIfaceName
|
||||
wgAddr := e.config.WgAddr
|
||||
myPrivateKey := e.config.WgPrivateKey
|
||||
var err error
|
||||
|
||||
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
networkName := "udp"
|
||||
if e.config.DisableIPv6Discovery {
|
||||
networkName = "udp4"
|
||||
}
|
||||
|
||||
transportNet, err := e.newStdNet()
|
||||
if err != nil {
|
||||
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
||||
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
||||
e.close()
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
|
||||
return err
|
||||
}
|
||||
udpMuxParams := ice.UDPMuxParams{
|
||||
UDPConn: e.udpMuxConn,
|
||||
Net: transportNet,
|
||||
}
|
||||
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
|
||||
|
||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
|
||||
|
||||
err = e.wgInterface.Create()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
|
||||
if err != nil {
|
||||
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
|
||||
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
if e.wgInterface.IsUserspaceBind() {
|
||||
iceBind := e.wgInterface.GetBind()
|
||||
udpMux, err := iceBind.GetICEMux()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
e.udpMux = udpMux.UDPMuxDefault
|
||||
e.udpMuxSrflx = udpMux
|
||||
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
||||
} else {
|
||||
networkName := "udp"
|
||||
if e.config.DisableIPv6Discovery {
|
||||
networkName = "udp4"
|
||||
}
|
||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
udpMuxParams := ice.UDPMuxParams{
|
||||
UDPConn: e.udpMuxConn,
|
||||
Net: transportNet,
|
||||
}
|
||||
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
|
||||
|
||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
|
||||
}
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||
|
||||
if e.dnsServer == nil {
|
||||
@@ -496,7 +503,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
||||
IP: e.config.WgAddr,
|
||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: conf.GetFqdn(),
|
||||
})
|
||||
|
||||
@@ -822,9 +829,10 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
ProxyConfig: proxyConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1006,12 +1014,6 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMuxSrflx != nil {
|
||||
if err := e.udpMuxSrflx.Close(); err != nil {
|
||||
log.Debugf("close server reflexive udp mux: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMuxConn != nil {
|
||||
if err := e.udpMuxConn.Close(); err != nil {
|
||||
log.Debugf("close udp mux connection: %v", err)
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet()
|
||||
return stdnet.NewNet(e.config.IFaceBlackList)
|
||||
}
|
||||
|
||||
@@ -3,5 +3,5 @@ package internal
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(e.config.IFaceDiscover)
|
||||
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -72,7 +74,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -206,12 +208,24 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
conn, err := net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -390,7 +404,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -548,8 +562,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@@ -713,8 +731,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, peer.NewRecorder("https://mgm"))
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
@@ -978,7 +1000,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil
|
||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
}
|
||||
|
||||
func startSignal() (*grpc.Server, string, error) {
|
||||
|
||||
13
client/internal/mobile_dependency.go
Normal file
13
client/internal/mobile_dependency.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// MobileDependency collect all dependencies for mobile platform
|
||||
type MobileDependency struct {
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
Routes []string
|
||||
}
|
||||
29
client/internal/mobile_dependency_android.go
Normal file
29
client/internal/mobile_dependency_android.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
)
|
||||
|
||||
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||
md := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: ifaceDiscover,
|
||||
}
|
||||
err := md.readMap(mgmClient)
|
||||
return md, err
|
||||
}
|
||||
|
||||
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
|
||||
routes, err := mgmClient.GetRoutes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Routes = make([]string, len(routes))
|
||||
for i, r := range routes {
|
||||
d.Routes[i] = r.GetNetwork()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
13
client/internal/mobile_dependency_nonandroid.go
Normal file
13
client/internal/mobile_dependency_nonandroid.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
)
|
||||
|
||||
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||
return MobileDependency{}, nil
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -46,6 +45,9 @@ type ConnConfig struct {
|
||||
LocalWgPort int
|
||||
|
||||
NATExternalIPs []string
|
||||
|
||||
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
|
||||
UserspaceBind bool
|
||||
}
|
||||
|
||||
// OfferAnswer represents a session establishment offer or answer
|
||||
@@ -95,7 +97,7 @@ type Conn struct {
|
||||
meta meta
|
||||
|
||||
adapter iface.TunAdapter
|
||||
iFaceDiscover stdnet.IFaceDiscover
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
@@ -121,7 +123,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*Conn, error) {
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||
return &Conn{
|
||||
config: config,
|
||||
mu: sync.Mutex{},
|
||||
@@ -136,32 +138,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
|
||||
}, nil
|
||||
}
|
||||
|
||||
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
|
||||
// to avoid building tunnel over them
|
||||
func interfaceFilter(blackList []string) func(string) bool {
|
||||
|
||||
return func(iFace string) bool {
|
||||
for _, s := range blackList {
|
||||
if strings.HasPrefix(iFace, s) {
|
||||
log.Debugf("ignoring interface %s - it is not allowed", iFace)
|
||||
return false
|
||||
}
|
||||
}
|
||||
// look for unlisted WireGuard interfaces
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
log.Debugf("trying to create a wgctrl client failed with: %v", err)
|
||||
return true
|
||||
}
|
||||
defer func() {
|
||||
_ = wg.Close()
|
||||
}()
|
||||
|
||||
_, err = wg.Device(iFace)
|
||||
return err != nil
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) reCreateAgent() error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
@@ -171,7 +147,7 @@ func (conn *Conn) reCreateAgent() error {
|
||||
var err error
|
||||
transportNet, err := conn.newStdNet()
|
||||
if err != nil {
|
||||
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
@@ -179,7 +155,7 @@ func (conn *Conn) reCreateAgent() error {
|
||||
Urls: conn.config.StunTurn,
|
||||
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
|
||||
FailedTimeout: &failedTimeout,
|
||||
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
|
||||
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||
UDPMux: conn.config.UDPMux,
|
||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||
@@ -319,7 +295,7 @@ func (conn *Conn) Open() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if conn.proxy.Type() == proxy.TypeNoProxy {
|
||||
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
||||
// direct Wireguard connection
|
||||
@@ -341,29 +317,62 @@ func (conn *Conn) Open() error {
|
||||
|
||||
// useProxy determines whether a direct connection (without a go proxy) is possible
|
||||
//
|
||||
// There are 2 cases:
|
||||
// There are 3 cases:
|
||||
//
|
||||
// * When neither candidate is from hard nat and one of the peers has a public IP
|
||||
//
|
||||
// * both peers are in the same private network
|
||||
//
|
||||
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
|
||||
//
|
||||
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
|
||||
func shouldUseProxy(pair *ice.CandidatePair) bool {
|
||||
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
|
||||
|
||||
if !isRelayCandidate(pair.Local) && userspaceBind {
|
||||
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
|
||||
return false
|
||||
}
|
||||
|
||||
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
|
||||
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
|
||||
return false
|
||||
}
|
||||
|
||||
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
|
||||
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
|
||||
return false
|
||||
}
|
||||
|
||||
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) {
|
||||
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
|
||||
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
|
||||
return false
|
||||
}
|
||||
|
||||
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
|
||||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
|
||||
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
|
||||
|
||||
localIP := net.ParseIP(pair.Local.Address())
|
||||
remoteIP := net.ParseIP(pair.Remote.Address())
|
||||
if localIP == nil || remoteIP == nil {
|
||||
return false
|
||||
}
|
||||
// only consider /16 networks
|
||||
mask := net.IPMask{255, 255, 0, 0}
|
||||
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
func isHardNATCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
|
||||
}
|
||||
@@ -376,9 +385,13 @@ func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
|
||||
}
|
||||
|
||||
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
|
||||
}
|
||||
|
||||
func isPublicIP(address string) bool {
|
||||
ip := net.ParseIP(address)
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
|
||||
if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -412,7 +425,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
}
|
||||
peerState.Direct = p.Type() == proxy.TypeNoProxy
|
||||
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
|
||||
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -423,8 +436,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
}
|
||||
|
||||
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
||||
|
||||
useProxy := shouldUseProxy(pair)
|
||||
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
|
||||
localDirectMode := !useProxy
|
||||
remoteDirectMode := localDirectMode
|
||||
|
||||
@@ -434,13 +446,16 @@ func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgP
|
||||
remoteDirectMode = conn.receiveRemoteDirectMode()
|
||||
}
|
||||
|
||||
if conn.config.UserspaceBind && localDirectMode {
|
||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
if localDirectMode && remoteDirectMode {
|
||||
log.Debugf("using WireGuard direct mode with peer %s", conn.config.Key)
|
||||
return proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
||||
return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
||||
}
|
||||
|
||||
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
|
||||
return proxy.NewWireguardProxy(conn.config.ProxyConfig)
|
||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
func (conn *Conn) sendLocalDirectMode(localMode bool) {
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/pion/ice/v2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -28,7 +30,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
|
||||
"Tailscale", "tailscale"}
|
||||
|
||||
filter := interfaceFilter(ignore)
|
||||
filter := stdnet.InterfaceFilter(ignore)
|
||||
|
||||
for _, s := range ignore {
|
||||
assert.Equal(t, filter(s), false)
|
||||
@@ -208,6 +210,7 @@ func TestConn_ShouldUseProxy(t *testing.T) {
|
||||
return ice.CandidateTypeHost
|
||||
},
|
||||
}
|
||||
|
||||
srflxCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "1.1.1.1"
|
||||
@@ -320,11 +323,47 @@ func TestConn_ShouldUseProxy(t *testing.T) {
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.102.168"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeHost
|
||||
}},
|
||||
Remote: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.101.96"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypePeerReflexive
|
||||
}},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.102.168"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypePeerReflexive
|
||||
}},
|
||||
Remote: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.101.96"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypePeerReflexive
|
||||
}},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := shouldUseProxy(testCase.candatePair)
|
||||
result := shouldUseProxy(testCase.candatePair, false)
|
||||
if result != testCase.expected {
|
||||
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
|
||||
}
|
||||
@@ -365,7 +404,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||
},
|
||||
inputDirectModeSupport: true,
|
||||
inputRemoteModeMessage: true,
|
||||
expected: proxy.TypeWireguard,
|
||||
expected: proxy.TypeWireGuard,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
|
||||
@@ -375,7 +414,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||
},
|
||||
inputDirectModeSupport: true,
|
||||
inputRemoteModeMessage: false,
|
||||
expected: proxy.TypeWireguard,
|
||||
expected: proxy.TypeWireGuard,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
|
||||
@@ -385,7 +424,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||
},
|
||||
inputDirectModeSupport: false,
|
||||
inputRemoteModeMessage: false,
|
||||
expected: proxy.TypeWireguard,
|
||||
expected: proxy.TypeWireGuard,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
|
||||
@@ -395,7 +434,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||
},
|
||||
inputDirectModeSupport: false,
|
||||
inputRemoteModeMessage: false,
|
||||
expected: proxy.TypeNoProxy,
|
||||
expected: proxy.TypeDirectNoProxy,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
|
||||
@@ -405,7 +444,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||
},
|
||||
inputDirectModeSupport: true,
|
||||
inputRemoteModeMessage: true,
|
||||
expected: proxy.TypeNoProxy,
|
||||
expected: proxy.TypeDirectNoProxy,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
|
||||
@@ -78,6 +78,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
defer d.mux.Unlock()
|
||||
d.offlinePeers = make([]State, len(replacement))
|
||||
copy(d.offlinePeers, replacement)
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
// AddPeer adds peer to Daemon status map
|
||||
@@ -308,7 +309,7 @@ func (d *Status) onConnectionChanged() {
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(len(d.peers))
|
||||
d.notifier.peerListChanged(len(d.peers) + len(d.offlinePeers))
|
||||
}
|
||||
|
||||
func (d *Status) notifyAddressChanged() {
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet()
|
||||
return stdnet.NewNet(conn.config.InterfaceBlackList)
|
||||
}
|
||||
|
||||
@@ -3,5 +3,5 @@ package peer
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(conn.iFaceDiscover)
|
||||
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
|
||||
}
|
||||
|
||||
57
client/internal/proxy/direct.go
Normal file
57
client/internal/proxy/direct.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
|
||||
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
|
||||
// This is possible in either of these cases:
|
||||
// - peers are in the same local network
|
||||
// - one of the peers has a public static IP (host)
|
||||
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
|
||||
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
|
||||
type DirectNoProxy struct {
|
||||
config Config
|
||||
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
||||
// It is used instead of the hardcoded 51820 port.
|
||||
RemoteWgListenPort int
|
||||
}
|
||||
|
||||
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
|
||||
func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
|
||||
return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
|
||||
}
|
||||
|
||||
// Close removes peer from the WireGuard interface
|
||||
func (p *DirectNoProxy) Close() error {
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start just updates WireGuard peer with the remote IP and default WireGuard port
|
||||
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
|
||||
|
||||
log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
|
||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr.Port = p.RemoteWgListenPort
|
||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
addr, p.config.PreSharedKey)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type returns the type of this proxy
|
||||
func (p *DirectNoProxy) Type() Type {
|
||||
return TypeDirectNoProxy
|
||||
}
|
||||
@@ -5,24 +5,18 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// NoProxy is used when there is no need for a proxy between ICE and Wireguard.
|
||||
// This is possible in either of these cases:
|
||||
// - peers are in the same local network
|
||||
// - one of the peers has a public static IP (host)
|
||||
// NoProxy will just update remote peer with a remote host and fixed Wireguard port (r.g. 51820).
|
||||
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
|
||||
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
||||
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
||||
type NoProxy struct {
|
||||
config Config
|
||||
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
||||
// It is used instead of the hardcoded 51820 port.
|
||||
RemoteWgListenPort int
|
||||
}
|
||||
|
||||
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port
|
||||
func NewNoProxy(config Config, remoteWgPort int) *NoProxy {
|
||||
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort}
|
||||
// NewNoProxy creates a new NoProxy with a provided config
|
||||
func NewNoProxy(config Config) *NoProxy {
|
||||
return &NoProxy{config: config}
|
||||
}
|
||||
|
||||
// Close removes peer from the WireGuard interface
|
||||
func (p *NoProxy) Close() error {
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
@@ -31,23 +25,16 @@ func (p *NoProxy) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start just updates Wireguard peer with the remote IP and default Wireguard port
|
||||
// Start just updates WireGuard peer with the remote address
|
||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
||||
|
||||
log.Debugf("using NoProxy while connecting to peer %s", p.config.RemoteKey)
|
||||
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr.Port = p.RemoteWgListenPort
|
||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
addr, p.config.PreSharedKey)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *NoProxy) Type() Type {
|
||||
|
||||
@@ -13,9 +13,10 @@ const DefaultWgKeepAlive = 25 * time.Second
|
||||
type Type string
|
||||
|
||||
const (
|
||||
TypeNoProxy Type = "NoProxy"
|
||||
TypeWireguard Type = "Wireguard"
|
||||
TypeDummy Type = "Dummy"
|
||||
TypeDirectNoProxy Type = "DirectNoProxy"
|
||||
TypeWireGuard Type = "WireGuard"
|
||||
TypeDummy Type = "Dummy"
|
||||
TypeNoProxy Type = "NoProxy"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// WireguardProxy proxies
|
||||
type WireguardProxy struct {
|
||||
// WireGuardProxy proxies
|
||||
type WireGuardProxy struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
@@ -17,13 +17,13 @@ type WireguardProxy struct {
|
||||
localConn net.Conn
|
||||
}
|
||||
|
||||
func NewWireguardProxy(config Config) *WireguardProxy {
|
||||
p := &WireguardProxy{config: config}
|
||||
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
||||
p := &WireGuardProxy{config: config}
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) updateEndpoint() error {
|
||||
func (p *WireGuardProxy) updateEndpoint() error {
|
||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -38,7 +38,7 @@ func (p *WireguardProxy) updateEndpoint() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) Start(remoteConn net.Conn) error {
|
||||
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
@@ -60,7 +60,7 @@ func (p *WireguardProxy) Start(remoteConn net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) Close() error {
|
||||
func (p *WireGuardProxy) Close() error {
|
||||
p.cancel()
|
||||
if c := p.localConn; c != nil {
|
||||
err := p.localConn.Close()
|
||||
@@ -77,7 +77,7 @@ func (p *WireguardProxy) Close() error {
|
||||
|
||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||
// blocks
|
||||
func (p *WireguardProxy) proxyToRemote() {
|
||||
func (p *WireGuardProxy) proxyToRemote() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
@@ -101,7 +101,7 @@ func (p *WireguardProxy) proxyToRemote() {
|
||||
|
||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||
// blocks
|
||||
func (p *WireguardProxy) proxyToLocal() {
|
||||
func (p *WireGuardProxy) proxyToLocal() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
@@ -123,6 +123,6 @@ func (p *WireguardProxy) proxyToLocal() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) Type() Type {
|
||||
return TypeWireguard
|
||||
func (p *WireGuardProxy) Type() Type {
|
||||
return TypeWireGuard
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
import "github.com/google/nftables"
|
||||
|
||||
const (
|
||||
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
@@ -1,9 +1,130 @@
|
||||
package routemanager
|
||||
|
||||
import "github.com/netbirdio/netbird/route"
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
Stop()
|
||||
}
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRouter *serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
}
|
||||
|
||||
// NewManager returns a new route manager
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
return &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
serverRouter: newServerRouter(ctx, wgInterface),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.stop()
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
ownNetworkIDs := make(map[string]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[networkID] = true
|
||||
// only linux is supported for now
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
continue
|
||||
}
|
||||
newServerRoutesMap[newRoute.ID] = newRoute
|
||||
}
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if !ownNetworkIDs[networkID] {
|
||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < 7 {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||
version.NetbirdVersion(), newRoute.Network)
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||
}
|
||||
}
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
}
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// DefaultManager dummy router manager for Android
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
serverRouter *serverRouter
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
// NewManager returns a new dummy route manager what doing nothing
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||
return &DefaultManager{}
|
||||
}
|
||||
|
||||
// UpdateRoutes ...
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop ...
|
||||
func (m *DefaultManager) Stop() {
|
||||
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRoutes map[string]*route.Route
|
||||
serverRouter *serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
}
|
||||
|
||||
// NewManager returns a new route manager
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
return &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
serverRoutes: make(map[string]*route.Route),
|
||||
serverRouter: &serverRouter{
|
||||
routes: make(map[string]*route.Route),
|
||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
||||
firewall: NewFirewall(ctx),
|
||||
},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.stop()
|
||||
m.serverRouter.firewall.CleanRoutingRules()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
}
|
||||
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for routeID := range m.serverRoutes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.serverRoutes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.serverRoutes, routeID)
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.serverRoutes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.serverRoutes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.serverRoutes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
ownNetworkIDs := make(map[string]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[networkID] = true
|
||||
// only linux is supported for now
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
continue
|
||||
}
|
||||
newServerRoutesMap[newRoute.ID] = newRoute
|
||||
}
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if !ownNetworkIDs[networkID] {
|
||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < 7 {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||
version.NetbirdVersion(), newRoute.Network)
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||
}
|
||||
}
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
|
||||
err := m.updateServerRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -391,7 +392,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil)
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@@ -414,7 +420,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||
|
||||
if testCase.shouldCheckServerRoutes {
|
||||
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
import "github.com/google/nftables"
|
||||
|
||||
const (
|
||||
nftablesTable = "netbird-rt"
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
24
client/internal/routemanager/router_pair.go
Normal file
24
client/internal/routemanager/router_pair.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type routerPair struct {
|
||||
ID string
|
||||
source string
|
||||
destination string
|
||||
masquerade bool
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) routerPair {
|
||||
parsed := netip.MustParsePrefix(source).Masked()
|
||||
return routerPair{
|
||||
ID: route.ID,
|
||||
source: parsed.String(),
|
||||
destination: route.Network.Masked().String(),
|
||||
masquerade: route.Masquerade,
|
||||
}
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
routes map[string]*route.Route
|
||||
// best effort to keep net forward configuration as it was
|
||||
netForwardHistoryEnabled bool
|
||||
mux sync.Mutex
|
||||
firewall firewallManager
|
||||
}
|
||||
|
||||
type routerPair struct {
|
||||
ID string
|
||||
source string
|
||||
destination string
|
||||
masquerade bool
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) routerPair {
|
||||
parsed := netip.MustParsePrefix(source).Masked()
|
||||
return routerPair{
|
||||
ID: route.ID,
|
||||
source: parsed.String(),
|
||||
destination: route.Network.Masked().String(),
|
||||
masquerade: route.Masquerade,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.serverRouter.mux.Lock()
|
||||
defer m.serverRouter.mux.Unlock()
|
||||
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.serverRouter.routes, route.ID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.serverRouter.mux.Lock()
|
||||
defer m.serverRouter.mux.Unlock()
|
||||
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.serverRouter.routes[route.ID] = route
|
||||
return nil
|
||||
}
|
||||
}
|
||||
21
client/internal/routemanager/server_android.go
Normal file
21
client/internal/routemanager/server_android.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||
return &serverRouter{}
|
||||
}
|
||||
|
||||
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *serverRouter) cleanUp() {}
|
||||
120
client/internal/routemanager/server_nonandroid.go
Normal file
120
client/internal/routemanager/server_nonandroid.go
Normal file
@@ -0,0 +1,120 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[string]*route.Route
|
||||
firewall firewallManager
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||
return &serverRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[string]*route.Route),
|
||||
firewall: NewFirewall(ctx),
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
err := m.firewall.RestoreOrCreateContainers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for routeID := range m.routes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.IsEqual(m.routes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.routes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.routes, routeID)
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.routes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.routes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.routes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.routes, route.ID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.routes[route.ID] = route
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) cleanUp() {
|
||||
m.firewall.CleanRoutingRules()
|
||||
}
|
||||
13
client/internal/routemanager/systemops_android.go
Normal file
13
client/internal/routemanager/systemops_android.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||
@@ -62,12 +65,3 @@ func enableIPForwarding() error {
|
||||
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
|
||||
return err
|
||||
}
|
||||
|
||||
func isNetForwardHistoryEnabled() bool {
|
||||
out, err := os.ReadFile(ipv4ForwardingPath)
|
||||
if err != nil {
|
||||
// todo
|
||||
panic(err)
|
||||
}
|
||||
return string(out) == "1"
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var errRouteNotFound = fmt.Errorf("route not found")
|
||||
@@ -3,6 +3,7 @@ package routemanager
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -32,7 +33,11 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||
@@ -34,8 +35,3 @@ func enableIPForwarding() error {
|
||||
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNetForwardHistoryEnabled() bool {
|
||||
log.Infof("check netforward history is not implemented on %s", runtime.GOOS)
|
||||
return false
|
||||
}
|
||||
|
||||
14
client/internal/stdnet/discover.go
Normal file
14
client/internal/stdnet/discover.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package stdnet
|
||||
|
||||
import "github.com/pion/transport/v2"
|
||||
|
||||
// ExternalIFaceDiscover provide an option for external services (mobile)
|
||||
// to collect network interface information
|
||||
type ExternalIFaceDiscover interface {
|
||||
// IFaces return with the description of the interfaces
|
||||
IFaces() (string, error)
|
||||
}
|
||||
|
||||
type iFaceDiscover interface {
|
||||
iFaces() ([]*transport.Interface, error)
|
||||
}
|
||||
98
client/internal/stdnet/discover_mobile.go
Normal file
98
client/internal/stdnet/discover_mobile.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type mobileIFaceDiscover struct {
|
||||
externalDiscover ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
func newMobileIFaceDiscover(externalDiscover ExternalIFaceDiscover) *mobileIFaceDiscover {
|
||||
return &mobileIFaceDiscover{
|
||||
externalDiscover: externalDiscover,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mobileIFaceDiscover) iFaces() ([]*transport.Interface, error) {
|
||||
ifacesString, err := m.externalDiscover.IFaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
interfaces := m.parseInterfacesString(ifacesString)
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transport.Interface {
|
||||
ifs := []*transport.Interface{}
|
||||
|
||||
for _, iface := range strings.Split(interfaces, "\n") {
|
||||
if strings.TrimSpace(iface) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Split(iface, "|")
|
||||
if len(fields) != 2 {
|
||||
log.Warnf("parseInterfacesString: unable to split %q", iface)
|
||||
continue
|
||||
}
|
||||
|
||||
var name string
|
||||
var index, mtu int
|
||||
var up, broadcast, loopback, pointToPoint, multicast bool
|
||||
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
|
||||
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
|
||||
if err != nil {
|
||||
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
|
||||
continue
|
||||
}
|
||||
|
||||
newIf := net.Interface{
|
||||
Name: name,
|
||||
Index: index,
|
||||
MTU: mtu,
|
||||
}
|
||||
if up {
|
||||
newIf.Flags |= net.FlagUp
|
||||
}
|
||||
if broadcast {
|
||||
newIf.Flags |= net.FlagBroadcast
|
||||
}
|
||||
if loopback {
|
||||
newIf.Flags |= net.FlagLoopback
|
||||
}
|
||||
if pointToPoint {
|
||||
newIf.Flags |= net.FlagPointToPoint
|
||||
}
|
||||
if multicast {
|
||||
newIf.Flags |= net.FlagMulticast
|
||||
}
|
||||
|
||||
ifc := transport.NewInterface(newIf)
|
||||
|
||||
addrs := strings.Trim(fields[1], " \n")
|
||||
foundAddress := false
|
||||
for _, addr := range strings.Split(addrs, " ") {
|
||||
if strings.Contains(addr, "%") {
|
||||
continue
|
||||
}
|
||||
ip, ipNet, err := net.ParseCIDR(addr)
|
||||
if err != nil {
|
||||
log.Warnf("%s", err)
|
||||
continue
|
||||
}
|
||||
ipNet.IP = ip
|
||||
ifc.AddAddress(ipNet)
|
||||
foundAddress = true
|
||||
}
|
||||
if foundAddress {
|
||||
ifs = append(ifs, ifc)
|
||||
}
|
||||
}
|
||||
return ifs
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package stdnet
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Test_parseInterfacesString(t *testing.T) {
|
||||
@@ -20,6 +22,7 @@ func Test_parseInterfacesString(t *testing.T) {
|
||||
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
|
||||
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
|
||||
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
|
||||
{"rmnet_data2", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2%rmnet2/64"},
|
||||
}
|
||||
|
||||
var exampleString string
|
||||
@@ -35,11 +38,13 @@ func Test_parseInterfacesString(t *testing.T) {
|
||||
d.multicast,
|
||||
d.addr)
|
||||
}
|
||||
nets := parseInterfacesString(exampleString)
|
||||
d := mobileIFaceDiscover{}
|
||||
nets := d.parseInterfacesString(exampleString)
|
||||
if len(nets) == 0 {
|
||||
t.Fatalf("failed to parse interfaces")
|
||||
}
|
||||
|
||||
log.Printf("%d", len(nets))
|
||||
for i, net := range nets {
|
||||
if net.MTU != testData[i].mtu {
|
||||
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
|
||||
@@ -58,7 +63,7 @@ func Test_parseInterfacesString(t *testing.T) {
|
||||
if len(addr) == 0 {
|
||||
t.Errorf("invalid address parsing")
|
||||
}
|
||||
|
||||
log.Printf("%v", addr)
|
||||
if addr[0].String() != testData[i].addr {
|
||||
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
|
||||
}
|
||||
36
client/internal/stdnet/discover_pion.go
Normal file
36
client/internal/stdnet/discover_pion.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
)
|
||||
|
||||
type pionDiscover struct {
|
||||
}
|
||||
|
||||
func (d pionDiscover) iFaces() ([]*transport.Interface, error) {
|
||||
ifs := []*transport.Interface{}
|
||||
|
||||
oifs, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, oif := range oifs {
|
||||
ifc := transport.NewInterface(oif)
|
||||
|
||||
addrs, err := oif.Addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ifc.AddAddress(addr)
|
||||
}
|
||||
|
||||
ifs = append(ifs, ifc)
|
||||
}
|
||||
|
||||
return ifs, nil
|
||||
}
|
||||
40
client/internal/stdnet/filter.go
Normal file
40
client/internal/stdnet/filter.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
)
|
||||
|
||||
// InterfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
|
||||
// to avoid building tunnel over them.
|
||||
func InterfaceFilter(disallowList []string) func(string) bool {
|
||||
|
||||
return func(iFace string) bool {
|
||||
|
||||
if strings.HasPrefix(iFace, "lo") {
|
||||
// hardcoded loopback check to support already installed agents
|
||||
return false
|
||||
}
|
||||
|
||||
for _, s := range disallowList {
|
||||
if strings.HasPrefix(iFace, s) {
|
||||
log.Debugf("ignoring interface %s - it is not allowed", iFace)
|
||||
return false
|
||||
}
|
||||
}
|
||||
// look for unlisted WireGuard interfaces
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
log.Debugf("trying to create a wgctrl client failed with: %v", err)
|
||||
return true
|
||||
}
|
||||
defer func() {
|
||||
_ = wg.Close()
|
||||
}()
|
||||
|
||||
_, err = wg.Device(iFace)
|
||||
return err != nil
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package stdnet
|
||||
|
||||
// IFaceDiscover provide an option for external services (mobile)
|
||||
// to collect network interface information
|
||||
type IFaceDiscover interface {
|
||||
// IFaces return with the description of the interfaces
|
||||
IFaces() (string, error)
|
||||
}
|
||||
@@ -5,37 +5,50 @@ package stdnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Net is an implementation of the net.Net interface
|
||||
// based on functions of the standard net package.
|
||||
type Net struct {
|
||||
stdnet.Net
|
||||
interfaces []*transport.Interface
|
||||
interfaces []*transport.Interface
|
||||
iFaceDiscover iFaceDiscover
|
||||
// interfaceFilter should return true if the given interfaceName is allowed
|
||||
interfaceFilter func(interfaceName string) bool
|
||||
}
|
||||
|
||||
// NewNetWithDiscover creates a new StdNet instance.
|
||||
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||
n := &Net{
|
||||
iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover),
|
||||
interfaceFilter: InterfaceFilter(disallowList),
|
||||
}
|
||||
return n, n.UpdateInterfaces()
|
||||
}
|
||||
|
||||
// NewNet creates a new StdNet instance.
|
||||
func NewNet(iFaceDiscover IFaceDiscover) (*Net, error) {
|
||||
n := &Net{}
|
||||
|
||||
return n, n.UpdateInterfaces(iFaceDiscover)
|
||||
func NewNet(disallowList []string) (*Net, error) {
|
||||
n := &Net{
|
||||
iFaceDiscover: pionDiscover{},
|
||||
interfaceFilter: InterfaceFilter(disallowList),
|
||||
}
|
||||
return n, n.UpdateInterfaces()
|
||||
}
|
||||
|
||||
// UpdateInterfaces updates the internal list of network interfaces
|
||||
// and associated addresses.
|
||||
func (n *Net) UpdateInterfaces(iFaceDiscover IFaceDiscover) error {
|
||||
ifacesString, err := iFaceDiscover.IFaces()
|
||||
// and associated addresses filtering them by name.
|
||||
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||
// wasn't specified.
|
||||
func (n *Net) UpdateInterfaces() (err error) {
|
||||
allIfaces, err := n.iFaceDiscover.iFaces()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.interfaces = parseInterfacesString(ifacesString)
|
||||
return err
|
||||
n.interfaces = n.filterInterfaces(allIfaces)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Interfaces returns a slice of interfaces which are available on the
|
||||
@@ -70,68 +83,15 @@ func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
|
||||
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
|
||||
}
|
||||
|
||||
func parseInterfacesString(interfaces string) []*transport.Interface {
|
||||
ifs := []*transport.Interface{}
|
||||
|
||||
for _, iface := range strings.Split(interfaces, "\n") {
|
||||
if strings.TrimSpace(iface) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Split(iface, "|")
|
||||
if len(fields) != 2 {
|
||||
log.Warnf("parseInterfacesString: unable to split %q", iface)
|
||||
continue
|
||||
}
|
||||
|
||||
var name string
|
||||
var index, mtu int
|
||||
var up, broadcast, loopback, pointToPoint, multicast bool
|
||||
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
|
||||
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
|
||||
if err != nil {
|
||||
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
|
||||
continue
|
||||
}
|
||||
|
||||
newIf := net.Interface{
|
||||
Name: name,
|
||||
Index: index,
|
||||
MTU: mtu,
|
||||
}
|
||||
if up {
|
||||
newIf.Flags |= net.FlagUp
|
||||
}
|
||||
if broadcast {
|
||||
newIf.Flags |= net.FlagBroadcast
|
||||
}
|
||||
if loopback {
|
||||
newIf.Flags |= net.FlagLoopback
|
||||
}
|
||||
if pointToPoint {
|
||||
newIf.Flags |= net.FlagPointToPoint
|
||||
}
|
||||
if multicast {
|
||||
newIf.Flags |= net.FlagMulticast
|
||||
}
|
||||
|
||||
ifc := transport.NewInterface(newIf)
|
||||
|
||||
addrs := strings.Trim(fields[1], " \n")
|
||||
foundAddress := false
|
||||
for _, addr := range strings.Split(addrs, " ") {
|
||||
ip, ipNet, err := net.ParseCIDR(addr)
|
||||
if err != nil {
|
||||
log.Warnf("%s", err)
|
||||
continue
|
||||
}
|
||||
ipNet.IP = ip
|
||||
ifc.AddAddress(ipNet)
|
||||
foundAddress = true
|
||||
}
|
||||
if foundAddress {
|
||||
ifs = append(ifs, ifc)
|
||||
func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.Interface {
|
||||
if n.interfaceFilter == nil {
|
||||
return interfaces
|
||||
}
|
||||
result := []*transport.Interface{}
|
||||
for _, iface := range interfaces {
|
||||
if n.interfaceFilter(iface.Name) {
|
||||
result = append(result, iface)
|
||||
}
|
||||
}
|
||||
return ifs
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -6,4 +6,4 @@
|
||||
#define EXPAND(x) STRINGIZE(x)
|
||||
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
|
||||
7 ICON ui/netbird.ico
|
||||
wireguard.dll RCDATA wireguard.dll
|
||||
wintun.dll RCDATA wintun.dll
|
||||
|
||||
@@ -43,6 +43,15 @@ func extractUserAgent(ctx context.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractDeviceName extracts device name from context or returns the default system name
|
||||
func extractDeviceName(ctx context.Context, defaultName string) string {
|
||||
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||
if !ok {
|
||||
return defaultName
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GetDesktopUIUserAgent returns the Desktop ui user agent
|
||||
func GetDesktopUIUserAgent() string {
|
||||
return "netbird-desktop-ui/" + version.NetbirdVersion()
|
||||
|
||||
@@ -24,21 +24,13 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname = extractDeviceName(ctx)
|
||||
gio.Hostname = extractDeviceName(ctx, "android")
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
return gio
|
||||
}
|
||||
|
||||
func extractDeviceName(ctx context.Context) string {
|
||||
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||
if !ok {
|
||||
return "android"
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func uname() []string {
|
||||
res := run("/system/bin/uname", "-a")
|
||||
return strings.Split(res, " ")
|
||||
|
||||
@@ -32,7 +32,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
swVersion = []byte(release)
|
||||
}
|
||||
gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(swVersion)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
osStr = strings.Replace(osStr, "\r\n", "", -1)
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
@@ -50,7 +50,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
osName = osInfo[3]
|
||||
}
|
||||
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
@@ -24,3 +24,12 @@ func Test_UIVersion(t *testing.T) {
|
||||
got := GetInfo(ctx)
|
||||
assert.Equal(t, want, got.UIVersion)
|
||||
}
|
||||
|
||||
func Test_CustomHostname(t *testing.T) {
|
||||
// nolint
|
||||
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
|
||||
want := "custom-host"
|
||||
|
||||
got := GetInfo(ctx)
|
||||
assert.Equal(t, want, got.Hostname)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,8 @@ import (
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
ver := getOSVersion()
|
||||
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||
gio.Hostname, _ = os.Hostname()
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
ldir=$PWD
|
||||
tmp_dir_path=$ldir/.distfiles
|
||||
winnt=wireguard-nt.zip
|
||||
download_file_path=$tmp_dir_path/$winnt
|
||||
download_url=https://download.wireguard.com/wireguard-nt/wireguard-nt-0.10.1.zip
|
||||
download_sha=772c0b1463d8d2212716f43f06f4594d880dea4f735165bd68e388fc41b81605
|
||||
|
||||
function resources_windows(){
|
||||
cmd=$1
|
||||
arch=$2
|
||||
out=$3
|
||||
docker run -i --rm -v $PWD:$PWD -w $PWD mstorsjo/llvm-mingw:latest $cmd -O coff -c 65001 -I $tmp_dir_path/wireguard-nt/bin/$arch -i resources.rc -o $out
|
||||
}
|
||||
|
||||
mkdir -p $tmp_dir_path
|
||||
curl -L#o $download_file_path.unverified $download_url
|
||||
echo "$download_sha $download_file_path.unverified" | sha256sum -c
|
||||
mv $download_file_path.unverified $download_file_path
|
||||
|
||||
mkdir -p .deps
|
||||
unzip $download_file_path -d $tmp_dir_path
|
||||
|
||||
resources_windows i686-w64-mingw32-windres x86 resources_windows_386.syso
|
||||
resources_windows aarch64-w64-mingw32-windres arm64 resources_windows_arm64.syso
|
||||
resources_windows x86_64-w64-mingw32-windres amd64 resources_windows_amd64.syso
|
||||
7
go.mod
7
go.mod
@@ -19,7 +19,7 @@ require (
|
||||
github.com/vishvananda/netlink v1.1.0
|
||||
golang.org/x/crypto v0.7.0
|
||||
golang.org/x/sys v0.6.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
|
||||
golang.zx2c4.com/wireguard/windows v0.5.1
|
||||
google.golang.org/grpc v1.52.3
|
||||
@@ -48,6 +48,8 @@ require (
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/open-policy-agent/opa v0.49.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pion/logging v0.2.2
|
||||
github.com/pion/stun v0.4.0
|
||||
github.com/pion/transport/v2 v2.0.2
|
||||
github.com/prometheus/client_golang v1.14.0
|
||||
github.com/rs/xid v1.3.0
|
||||
@@ -103,10 +105,8 @@ require (
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.6 // indirect
|
||||
github.com/pion/logging v0.2.2 // indirect
|
||||
github.com/pion/mdns v0.0.7 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/stun v0.4.0 // indirect
|
||||
github.com/pion/turn/v2 v2.1.0 // indirect
|
||||
github.com/pion/udp/v2 v2.0.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
@@ -131,7 +131,6 @@ require (
|
||||
golang.org/x/mod v0.8.0 // indirect
|
||||
golang.org/x/text v0.8.0 // indirect
|
||||
golang.org/x/tools v0.6.0 // indirect
|
||||
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
|
||||
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
|
||||
5
go.sum
5
go.sum
@@ -881,13 +881,12 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
|
||||
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 h1:3zl8RkJNQ8wfPRomwv/6DBbH2Ut6dgMaWTxM0ZunWnE=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=
|
||||
|
||||
208
iface/bind/bind.go
Normal file
208
iface/bind/bind.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library
|
||||
type ICEBind struct {
|
||||
// below fields, initialized on open
|
||||
ipv4 net.PacketConn
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
|
||||
// below are fields initialized on creation
|
||||
transportNet transport.Net
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewICEBind create a new instance of ICEBind with a given transportNet function.
|
||||
// The transportNet can be nil.
|
||||
func NewICEBind(transportNet transport.Net) *ICEBind {
|
||||
return &ICEBind{
|
||||
transportNet: transportNet,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.udpMux == nil {
|
||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||
}
|
||||
|
||||
return b.udpMux, nil
|
||||
}
|
||||
|
||||
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
|
||||
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.ipv4 != nil {
|
||||
return nil, 0, conn.ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
var err error
|
||||
b.ipv4, _, err = listenNet("udp4", int(uport))
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.ipv4, Net: b.transportNet})
|
||||
|
||||
portAddr, err := netip.ParseAddrPort(b.ipv4.LocalAddr().String())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
log.Infof("opened ICEBind on %s", b.ipv4.LocalAddr().String())
|
||||
|
||||
return []conn.ReceiveFunc{
|
||||
b.makeReceiveIPv4(b.ipv4),
|
||||
},
|
||||
portAddr.Port(), nil
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (net.PacketConn, int, error) {
|
||||
c, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
lAddr := c.LocalAddr()
|
||||
uAddr, err := net.ResolveUDPAddr(
|
||||
lAddr.Network(),
|
||||
lAddr.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return c, uAddr.Port, nil
|
||||
}
|
||||
|
||||
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||
msg := &stun.Message{
|
||||
Raw: raw,
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
|
||||
return func(buff []byte) (int, conn.Endpoint, error) {
|
||||
n, endpoint, err := c.ReadFrom(buff)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
e, err := netip.ParseAddrPort(endpoint.String())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if !stun.IsMessage(buff) {
|
||||
// WireGuard traffic
|
||||
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
|
||||
}
|
||||
|
||||
msg, err := parseSTUNMessage(buff[:n])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
|
||||
if err != nil {
|
||||
log.Warnf("failed to handle packet")
|
||||
}
|
||||
|
||||
// discard packets because they are STUN related
|
||||
return 0, nil, nil //todo proper return
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the WireGuard socket and UDPMux
|
||||
func (b *ICEBind) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
var err1, err2 error
|
||||
if b.ipv4 != nil {
|
||||
c := b.ipv4
|
||||
b.ipv4 = nil
|
||||
err1 = c.Close()
|
||||
}
|
||||
|
||||
if b.udpMux != nil {
|
||||
m := b.udpMux
|
||||
b.udpMux = nil
|
||||
err2 = m.Close()
|
||||
}
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
return err2
|
||||
}
|
||||
|
||||
// SetMark sets the mark for each packet sent through this Bind.
|
||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
func (b *ICEBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send bytes to the remote endpoint (peer)
|
||||
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
|
||||
|
||||
nend, ok := endpoint.(conn.StdNetEndpoint)
|
||||
if !ok {
|
||||
return conn.ErrWrongEndpointType
|
||||
}
|
||||
addrPort := netip.AddrPort(nend)
|
||||
_, err := b.ipv4.WriteTo(buff, &net.UDPAddr{
|
||||
IP: addrPort.Addr().AsSlice(),
|
||||
Port: int(addrPort.Port()),
|
||||
Zone: addrPort.Addr().Zone(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string.
|
||||
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
return asEndpoint(e), err
|
||||
}
|
||||
|
||||
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
||||
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
||||
// but Endpoints are immutable, so we can re-use them.
|
||||
var endpointPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make(map[netip.AddrPort]conn.Endpoint)
|
||||
},
|
||||
}
|
||||
|
||||
// asEndpoint returns an Endpoint containing ap.
|
||||
func asEndpoint(ap netip.AddrPort) conn.Endpoint {
|
||||
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint)
|
||||
defer endpointPool.Put(m)
|
||||
e, ok := m[ap]
|
||||
if !ok {
|
||||
e = conn.Endpoint(conn.StdNetEndpoint(ap))
|
||||
m[ap] = e
|
||||
}
|
||||
return e
|
||||
}
|
||||
445
iface/bind/udp_mux.go
Normal file
445
iface/bind/udp_mux.go
Normal file
@@ -0,0 +1,445 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/transport/v2"
|
||||
)
|
||||
|
||||
/*
|
||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
||||
*/
|
||||
|
||||
const receiveMTU = 8192
|
||||
|
||||
// UDPMuxDefault is an implementation of the interface
|
||||
type UDPMuxDefault struct {
|
||||
params UDPMuxParams
|
||||
|
||||
closedChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
|
||||
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
||||
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
||||
|
||||
addressMapMu sync.RWMutex
|
||||
addressMap map[string][]*udpMuxedConn
|
||||
|
||||
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
|
||||
pool *sync.Pool
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
// for UDP connection listen at unspecified address
|
||||
localAddrsForUnspecified []net.Addr
|
||||
}
|
||||
|
||||
const maxAddrSize = 512
|
||||
|
||||
// UDPMuxParams are parameters for UDPMux.
|
||||
type UDPMuxParams struct {
|
||||
Logger logging.LeveledLogger
|
||||
UDPConn net.PacketConn
|
||||
|
||||
// Required for gathering local addresses
|
||||
// in case a un UDPConn is passed which does not
|
||||
// bind to a specific local address.
|
||||
Net transport.Net
|
||||
InterfaceFilter func(interfaceName string) bool
|
||||
}
|
||||
|
||||
func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
|
||||
ips := []net.IP{}
|
||||
ifaces, err := n.Interfaces()
|
||||
if err != nil {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
var IPv4Requested, IPv6Requested bool
|
||||
for _, typ := range networkTypes {
|
||||
if typ.IsIPv4() {
|
||||
IPv4Requested = true
|
||||
}
|
||||
|
||||
if typ.IsIPv6() {
|
||||
IPv6Requested = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue // interface down
|
||||
}
|
||||
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
|
||||
continue // loopback interface
|
||||
}
|
||||
|
||||
if interfaceFilter != nil && !interfaceFilter(iface.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch addr := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = addr.IP
|
||||
case *net.IPAddr:
|
||||
ip = addr.IP
|
||||
}
|
||||
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 == nil {
|
||||
if !IPv6Requested {
|
||||
continue
|
||||
} else if !isSupportedIPv6(ip) {
|
||||
continue
|
||||
}
|
||||
} else if !IPv4Requested {
|
||||
continue
|
||||
}
|
||||
|
||||
if ipFilter != nil && !ipFilter(ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// The conditions of invalidation written below are defined in
|
||||
// https://tools.ietf.org/html/rfc8445#section-5.1.1.1
|
||||
func isSupportedIPv6(ip net.IP) bool {
|
||||
if len(ip) != net.IPv6len ||
|
||||
isZeros(ip[0:12]) || // !(IPv4-compatible IPv6)
|
||||
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast)
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isZeros(ip net.IP) bool {
|
||||
for i := 0; i < len(ip); i++ {
|
||||
if ip[i] != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
}
|
||||
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
default:
|
||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
if params.Net == nil {
|
||||
var err error
|
||||
if params.Net, err = stdnet.NewNet(); err != nil {
|
||||
params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
}
|
||||
} else {
|
||||
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &UDPMuxDefault{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
connsIPv6: make(map[string]*udpMuxedConn),
|
||||
closedChan: make(chan struct{}, 1),
|
||||
pool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
// big enough buffer to fit both packet and address
|
||||
return newBufferHolder(receiveMTU + maxAddrSize)
|
||||
},
|
||||
},
|
||||
localAddrsForUnspecified: localAddrsForUnspecified,
|
||||
}
|
||||
}
|
||||
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||
func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
return m.params.UDPConn.LocalAddr()
|
||||
}
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
if len(m.localAddrsForUnspecified) > 0 {
|
||||
return m.localAddrsForUnspecified
|
||||
}
|
||||
|
||||
return []net.Addr{m.LocalAddr()}
|
||||
}
|
||||
|
||||
// GetConn returns a PacketConn given the connection's ufrag and network address
|
||||
// creates the connection if an existing one can't be found
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||
|
||||
var isIPv6 bool
|
||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||
isIPv6 = true
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.IsClosed() {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
if conn, ok := m.getConn(ufrag, isIPv6); ok {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
c := m.createMuxedConn(ufrag)
|
||||
go func() {
|
||||
<-c.CloseChannel()
|
||||
m.RemoveConnByUfrag(ufrag)
|
||||
}()
|
||||
|
||||
if isIPv6 {
|
||||
m.connsIPv6[ufrag] = c
|
||||
} else {
|
||||
m.connsIPv4[ufrag] = c
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
removedConns := make([]*udpMuxedConn, 0, 2)
|
||||
|
||||
// Keep lock section small to avoid deadlock with conn lock
|
||||
m.mu.Lock()
|
||||
if c, ok := m.connsIPv4[ufrag]; ok {
|
||||
delete(m.connsIPv4, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
}
|
||||
if c, ok := m.connsIPv6[ufrag]; ok {
|
||||
delete(m.connsIPv6, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(removedConns) == 0 {
|
||||
// No need to lock if no connection was found
|
||||
return
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
for _, c := range removedConns {
|
||||
addresses := c.getAddresses()
|
||||
for _, addr := range addresses {
|
||||
if connList, ok := m.addressMap[addr]; ok {
|
||||
var newList []*udpMuxedConn
|
||||
for _, conn := range connList {
|
||||
if conn.params.Key != ufrag {
|
||||
newList = append(newList, conn)
|
||||
}
|
||||
}
|
||||
m.addressMap[addr] = newList
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsClosed returns true if the mux had been closed
|
||||
func (m *UDPMuxDefault) IsClosed() bool {
|
||||
select {
|
||||
case <-m.closedChan:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Close the mux, no further connections could be created
|
||||
func (m *UDPMuxDefault) Close() error {
|
||||
var err error
|
||||
m.closeOnce.Do(func() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, c := range m.connsIPv4 {
|
||||
_ = c.Close()
|
||||
}
|
||||
for _, c := range m.connsIPv6 {
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
m.connsIPv4 = make(map[string]*udpMuxedConn)
|
||||
m.connsIPv6 = make(map[string]*udpMuxedConn)
|
||||
|
||||
close(m.closedChan)
|
||||
|
||||
_ = m.params.UDPConn.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
return m.params.UDPConn.WriteTo(buf, rAddr)
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
||||
if m.IsClosed() {
|
||||
return
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
existing, ok := m.addressMap[addr]
|
||||
if !ok {
|
||||
existing = []*udpMuxedConn{}
|
||||
}
|
||||
existing = append(existing, conn)
|
||||
m.addressMap[addr] = existing
|
||||
|
||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
||||
c := newUDPMuxedConn(&udpMuxedConnParams{
|
||||
Mux: m,
|
||||
Key: key,
|
||||
AddrPool: m.pool,
|
||||
LocalAddr: m.LocalAddr(),
|
||||
Logger: m.params.Logger,
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
|
||||
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
||||
|
||||
remoteAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
|
||||
}
|
||||
|
||||
// If we have already seen this address dispatch to the appropriate destination
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
m.addressMapMu.Lock()
|
||||
var destinationConnList []*udpMuxedConn
|
||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||
destinationConnList = append(destinationConnList, storedConns...)
|
||||
}
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
var isIPv6 bool
|
||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||
isIPv6 = true
|
||||
}
|
||||
|
||||
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
|
||||
// However, we can take a username attribute from the STUN message which contains ufrag.
|
||||
// We can use ufrag to identify the destination conn to route packet to.
|
||||
attr, stunAttrErr := msg.Get(stun.AttrUsername)
|
||||
if stunAttrErr == nil {
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
|
||||
m.mu.Lock()
|
||||
destinationConn := m.connsIPv4[ufrag]
|
||||
if isIPv6 {
|
||||
destinationConn = m.connsIPv6[ufrag]
|
||||
}
|
||||
|
||||
if destinationConn != nil {
|
||||
exists := false
|
||||
for _, conn := range destinationConnList {
|
||||
if conn.params.Key == destinationConn.params.Key {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
destinationConnList = append(destinationConnList, destinationConn)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
|
||||
// It will be discarded by the further ICE candidate logic if so.
|
||||
for _, conn := range destinationConnList {
|
||||
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
|
||||
log.Errorf("could not write packet: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||
if isIPv6 {
|
||||
val, ok = m.connsIPv6[ufrag]
|
||||
} else {
|
||||
val, ok = m.connsIPv4[ufrag]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type bufferHolder struct {
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newBufferHolder(size int) *bufferHolder {
|
||||
return &bufferHolder{
|
||||
buf: make([]byte, size),
|
||||
}
|
||||
}
|
||||
254
iface/bind/udp_mux_universal.go
Normal file
254
iface/bind/udp_mux_universal.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package bind
|
||||
|
||||
/*
|
||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/transport/v2"
|
||||
)
|
||||
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||
type UniversalUDPMuxDefault struct {
|
||||
*UDPMuxDefault
|
||||
params UniversalUDPMuxParams
|
||||
|
||||
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
|
||||
// stun.XORMappedAddress indexed by the STUN server addr
|
||||
xorMappedMap map[string]*xorMapped
|
||||
}
|
||||
|
||||
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
|
||||
type UniversalUDPMuxParams struct {
|
||||
Logger logging.LeveledLogger
|
||||
UDPConn net.PacketConn
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
}
|
||||
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
}
|
||||
if params.XORMappedAddrCacheTTL == 0 {
|
||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||
}
|
||||
|
||||
m := &UniversalUDPMuxDefault{
|
||||
params: params,
|
||||
xorMappedMap: make(map[string]*xorMapped),
|
||||
}
|
||||
|
||||
// wrap UDP connection, process server reflexive messages
|
||||
// before they are passed to the UDPMux connection handler (connWorker)
|
||||
m.params.UDPConn = &udpConn{
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
}
|
||||
|
||||
// embed UDPMux
|
||||
udpMuxParams := UDPMuxParams{
|
||||
Logger: params.Logger,
|
||||
UDPConn: m.params.UDPConn,
|
||||
Net: m.params.Net,
|
||||
}
|
||||
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type udpConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
}
|
||||
|
||||
// GetListenAddresses returns the listen addr of this UDP
|
||||
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
return []net.Addr{m.LocalAddr()}
|
||||
}
|
||||
|
||||
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
|
||||
// Not implemented yet.
|
||||
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
|
||||
return nil, fmt.Errorf("not implemented yet")
|
||||
}
|
||||
|
||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||
// and return a unique connection per server.
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
|
||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
|
||||
}
|
||||
|
||||
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
|
||||
// All other STUN packets will be forwarded to the UDPMux
|
||||
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
// message about this err will be logged in the UDPMux
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.isXORMappedResponse(msg, udpAddr.String()) {
|
||||
err := m.handleXORMappedResponse(udpAddr, msg)
|
||||
if err != nil {
|
||||
log.Debugf("%s: %v", fmt.Errorf("failed to get XOR-MAPPED-ADDRESS response"), err)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
|
||||
}
|
||||
|
||||
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
|
||||
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
|
||||
_, ok := m.xorMappedMap[stunAddr]
|
||||
_, err := msg.Get(stun.AttrXORMappedAddress)
|
||||
return err == nil && ok
|
||||
}
|
||||
|
||||
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
|
||||
// and set the mapped address for the server
|
||||
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
|
||||
if !ok {
|
||||
return fmt.Errorf("no XOR address mapping")
|
||||
}
|
||||
|
||||
var addr stun.XORMappedAddress
|
||||
if err := addr.GetFrom(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.xorMappedMap[stunAddr.String()] = mappedAddr
|
||||
mappedAddr.SetAddr(&addr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
|
||||
// Makes a STUN binding request to discover mapped address otherwise.
|
||||
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
|
||||
// Method is safe for concurrent use.
|
||||
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
|
||||
m.mu.Lock()
|
||||
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
|
||||
// if we already have a mapping for this STUN server (address already received)
|
||||
// and if it is not too old we return it without making a new request to STUN server
|
||||
if ok {
|
||||
if mappedAddr.expired() {
|
||||
mappedAddr.closeWaiters()
|
||||
delete(m.xorMappedMap, serverAddr.String())
|
||||
ok = false
|
||||
} else if mappedAddr.pending() {
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if ok {
|
||||
return mappedAddr.addr, nil
|
||||
}
|
||||
|
||||
// otherwise, make a STUN request to discover the address
|
||||
// or wait for already sent request to complete
|
||||
waitAddrReceived, err := m.sendStun(serverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %s", "failed to send STUN packet", err)
|
||||
}
|
||||
|
||||
// block until response was handled by the connWorker routine and XORMappedAddress was updated
|
||||
select {
|
||||
case <-waitAddrReceived:
|
||||
// when channel closed, addr was obtained
|
||||
m.mu.Lock()
|
||||
mappedAddr := *m.xorMappedMap[serverAddr.String()]
|
||||
m.mu.Unlock()
|
||||
if mappedAddr.addr == nil {
|
||||
return nil, fmt.Errorf("no XOR address mapping")
|
||||
}
|
||||
return mappedAddr.addr, nil
|
||||
case <-time.After(deadline):
|
||||
return nil, fmt.Errorf("timeout while waiting for XORMappedAddr")
|
||||
}
|
||||
}
|
||||
|
||||
// sendStun sends a STUN request via UDP conn.
|
||||
//
|
||||
// The returned channel is closed when the STUN response has been received.
|
||||
// Method is safe for concurrent use.
|
||||
func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// if record present in the map, we already sent a STUN request,
|
||||
// just wait when waitAddrReceived will be closed
|
||||
addrMap, ok := m.xorMappedMap[serverAddr.String()]
|
||||
if !ok {
|
||||
addrMap = &xorMapped{
|
||||
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
|
||||
waitAddrReceived: make(chan struct{}),
|
||||
}
|
||||
m.xorMappedMap[serverAddr.String()] = addrMap
|
||||
}
|
||||
|
||||
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return addrMap.waitAddrReceived, nil
|
||||
}
|
||||
|
||||
type xorMapped struct {
|
||||
addr *stun.XORMappedAddress
|
||||
waitAddrReceived chan struct{}
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func (a *xorMapped) closeWaiters() {
|
||||
select {
|
||||
case <-a.waitAddrReceived:
|
||||
// notify was close, ok, that means we received duplicate response
|
||||
// just exit
|
||||
break
|
||||
default:
|
||||
// notify tha twe have a new addr
|
||||
close(a.waitAddrReceived)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *xorMapped) pending() bool {
|
||||
return a.addr == nil
|
||||
}
|
||||
|
||||
func (a *xorMapped) expired() bool {
|
||||
return a.expiresAt.Before(time.Now())
|
||||
}
|
||||
|
||||
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
|
||||
a.addr = addr
|
||||
a.closeWaiters()
|
||||
}
|
||||
233
iface/bind/udp_muxed_conn.go
Normal file
233
iface/bind/udp_muxed_conn.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package bind
|
||||
|
||||
/*
|
||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
||||
*/
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/transport/v2/packetio"
|
||||
)
|
||||
|
||||
type udpMuxedConnParams struct {
|
||||
Mux *UDPMuxDefault
|
||||
AddrPool *sync.Pool
|
||||
Key string
|
||||
LocalAddr net.Addr
|
||||
Logger logging.LeveledLogger
|
||||
}
|
||||
|
||||
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
||||
type udpMuxedConn struct {
|
||||
params *udpMuxedConnParams
|
||||
// remote addresses that we have sent to on this conn
|
||||
addresses []string
|
||||
|
||||
// channel holding incoming packets
|
||||
buf *packetio.Buffer
|
||||
closedChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
|
||||
p := &udpMuxedConn{
|
||||
params: params,
|
||||
buf: packetio.NewBuffer(),
|
||||
closedChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
|
||||
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
|
||||
defer c.params.AddrPool.Put(buf)
|
||||
|
||||
// read address
|
||||
total, err := c.buf.Read(buf.buf)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
|
||||
if dataLen > total || dataLen > len(b) {
|
||||
return 0, nil, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
// read data and then address
|
||||
offset := 2
|
||||
copy(b, buf.buf[offset:offset+dataLen])
|
||||
offset += dataLen
|
||||
|
||||
// read address len & decode address
|
||||
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
|
||||
offset += 2
|
||||
|
||||
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
return dataLen, rAddr, nil
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
// each time we write to a new address, we'll register it with the mux
|
||||
addr := rAddr.String()
|
||||
if !c.containsAddress(addr) {
|
||||
c.addAddress(addr)
|
||||
}
|
||||
|
||||
return c.params.Mux.writeTo(buf, rAddr)
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) LocalAddr() net.Addr {
|
||||
return c.params.LocalAddr
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
|
||||
return c.closedChan
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
err = c.buf.Close()
|
||||
close(c.closedChan)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) isClosed() bool {
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) getAddresses() []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
addresses := make([]string, len(c.addresses))
|
||||
copy(addresses, c.addresses)
|
||||
return addresses
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) addAddress(addr string) {
|
||||
c.mu.Lock()
|
||||
c.addresses = append(c.addresses, addr)
|
||||
c.mu.Unlock()
|
||||
|
||||
// map it on mux
|
||||
c.params.Mux.registerConnForAddress(c, addr)
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) containsAddress(addr string) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for _, a := range c.addresses {
|
||||
if addr == a {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
|
||||
// write two packets, address and data
|
||||
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
|
||||
defer c.params.AddrPool.Put(buf)
|
||||
|
||||
// format of buffer | data len | data bytes | addr len | addr bytes |
|
||||
if len(buf.buf) < len(data)+maxAddrSize {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
// data len
|
||||
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
|
||||
offset := 2
|
||||
|
||||
// data
|
||||
copy(buf.buf[offset:], data)
|
||||
offset += len(data)
|
||||
|
||||
// write address first, leaving room for its length
|
||||
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
total := offset + n + 2
|
||||
|
||||
// address len
|
||||
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
|
||||
|
||||
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
|
||||
ipData, err := addr.IP.MarshalText()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total := 2 + len(ipData) + 2 + len(addr.Zone)
|
||||
if total > len(buf) {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
binary.LittleEndian.PutUint16(buf, uint16(len(ipData)))
|
||||
offset := 2
|
||||
n := copy(buf[offset:], ipData)
|
||||
offset += n
|
||||
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
|
||||
offset += 2
|
||||
copy(buf[offset:], addr.Zone)
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
|
||||
addr := net.UDPAddr{}
|
||||
|
||||
offset := 0
|
||||
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
|
||||
offset += 2
|
||||
// basic bounds checking
|
||||
if ipLen+offset > len(buf) {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset += ipLen
|
||||
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
|
||||
offset += 2
|
||||
zone := make([]byte, len(buf[offset:]))
|
||||
copy(zone, buf[offset:])
|
||||
addr.Zone = string(zone)
|
||||
|
||||
return &addr, nil
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -16,9 +18,20 @@ const (
|
||||
|
||||
// WGIface represents a interface instance
|
||||
type WGIface struct {
|
||||
tun *tunDevice
|
||||
configurer wGConfigurer
|
||||
mu sync.Mutex
|
||||
tun *tunDevice
|
||||
configurer wGConfigurer
|
||||
mu sync.Mutex
|
||||
userspaceBind bool
|
||||
}
|
||||
|
||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||
func (w *WGIface) IsUserspaceBind() bool {
|
||||
return w.userspaceBind
|
||||
}
|
||||
|
||||
// GetBind returns a userspace implementation of WireGuard Bind interface
|
||||
func (w *WGIface) GetBind() *bind.ICEBind {
|
||||
return w.tun.iceBind
|
||||
}
|
||||
|
||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
@@ -26,7 +39,7 @@ type WGIface struct {
|
||||
func (w *WGIface) Create() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
log.Debugf("create Wireguard interface %s", w.tun.DeviceName())
|
||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
||||
return w.tun.Create()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
package iface
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sync"
|
||||
|
||||
// NewWGIFace Creates a new Wireguard interface instance
|
||||
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
|
||||
wgIface := &WGIface{
|
||||
"github.com/pion/transport/v2"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
|
||||
wgAddress, err := parseWGAddress(address)
|
||||
if err != nil {
|
||||
return wgIface, err
|
||||
return wgIFace, err
|
||||
}
|
||||
|
||||
tun := newTunDevice(wgAddress, mtu, tunAdapter)
|
||||
wgIface.tun = tun
|
||||
tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet)
|
||||
wgIFace.tun = tun
|
||||
|
||||
wgIface.configurer = newWGConfigurer(tun)
|
||||
wgIFace.configurer = newWGConfigurer(tun)
|
||||
|
||||
return wgIface, nil
|
||||
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
|
||||
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -2,21 +2,26 @@
|
||||
|
||||
package iface
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sync"
|
||||
|
||||
// NewWGIFace Creates a new Wireguard interface instance
|
||||
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
|
||||
wgIface := &WGIface{
|
||||
"github.com/pion/transport/v2"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
|
||||
wgAddress, err := parseWGAddress(address)
|
||||
if err != nil {
|
||||
return wgIface, err
|
||||
return wgIFace, err
|
||||
}
|
||||
|
||||
wgIface.tun = newTunDevice(ifaceName, wgAddress, mtu)
|
||||
wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet)
|
||||
|
||||
wgIface.configurer = newWGConfigurer(ifaceName)
|
||||
return wgIface, nil
|
||||
wgIFace.configurer = newWGConfigurer(iFaceName)
|
||||
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -2,13 +2,15 @@ package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// keep darwin compability
|
||||
@@ -32,7 +34,12 @@ func init() {
|
||||
func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
addr := "100.64.0.1/8"
|
||||
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -92,7 +99,11 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
||||
func Test_CreateInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||
wgIP := "10.99.99.1/32"
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -121,7 +132,11 @@ func Test_CreateInterface(t *testing.T) {
|
||||
func Test_Close(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -149,7 +164,11 @@ func Test_Close(t *testing.T) {
|
||||
func Test_ConfigureInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||
wgIP := "10.99.99.5/30"
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -196,7 +215,11 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
func Test_UpdatePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.9/30"
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -255,7 +278,11 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
func Test_RemovePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.13/30"
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -304,8 +331,11 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
peer2Key, _ := wgtypes.GeneratePrivateKey()
|
||||
|
||||
keepAlive := 1 * time.Second
|
||||
|
||||
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -322,7 +352,11 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil)
|
||||
newNet, err = stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
package iface
|
||||
|
||||
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||
func WireguardModuleIsLoaded() bool {
|
||||
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -7,9 +7,6 @@ import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math"
|
||||
@@ -17,6 +14,10 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Holds logic to check existence of kernel modules used by wireguard interfaces
|
||||
@@ -33,6 +34,7 @@ const (
|
||||
loading
|
||||
live
|
||||
inuse
|
||||
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
|
||||
)
|
||||
|
||||
type module struct {
|
||||
@@ -81,9 +83,15 @@ func tunModuleIsLoaded() bool {
|
||||
return tunLoaded
|
||||
}
|
||||
|
||||
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
|
||||
func WireguardModuleIsLoaded() bool {
|
||||
if canCreateFakeWireguardInterface() {
|
||||
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
|
||||
if os.Getenv(envDisableWireGuardKernel) == "true" {
|
||||
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
|
||||
return false
|
||||
}
|
||||
|
||||
if canCreateFakeWireGuardInterface() {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -96,7 +104,7 @@ func WireguardModuleIsLoaded() bool {
|
||||
return loaded
|
||||
}
|
||||
|
||||
func canCreateFakeWireguardInterface() bool {
|
||||
func canCreateFakeWireGuardInterface() bool {
|
||||
link := newWGLink("mustnotexist")
|
||||
|
||||
// We willingly try to create a device with an invalid
|
||||
|
||||
@@ -3,13 +3,14 @@ package iface
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func TestGetModuleDependencies(t *testing.T) {
|
||||
|
||||
@@ -2,6 +2,6 @@ package iface
|
||||
|
||||
// TunAdapter is an interface for create tun device from externel service
|
||||
type TunAdapter interface {
|
||||
ConfigureInterface(address string, mtu int) (int, error)
|
||||
ConfigureInterface(address string, mtu int, routes string) (int, error)
|
||||
UpdateAddr(address string) error
|
||||
}
|
||||
|
||||
@@ -1,38 +1,43 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
)
|
||||
|
||||
type tunDevice struct {
|
||||
address WGAddress
|
||||
mtu int
|
||||
routes []string
|
||||
tunAdapter TunAdapter
|
||||
|
||||
fd int
|
||||
name string
|
||||
device *device.Device
|
||||
uapi net.Listener
|
||||
fd int
|
||||
name string
|
||||
device *device.Device
|
||||
iceBind *bind.ICEBind
|
||||
}
|
||||
|
||||
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter) *tunDevice {
|
||||
func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
|
||||
return &tunDevice{
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
routes: routes,
|
||||
tunAdapter: tunAdapter,
|
||||
iceBind: bind.NewICEBind(transportNet),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tunDevice) Create() error {
|
||||
var err error
|
||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu)
|
||||
routesString := t.routesToString()
|
||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
return err
|
||||
@@ -46,35 +51,11 @@ func (t *tunDevice) Create() error {
|
||||
t.name = name
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||
t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
log.Debugf("create uapi")
|
||||
tunSock, err := ipc.UAPIOpen(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.uapi, err = ipc.UAPIListen(name, tunSock)
|
||||
if err != nil {
|
||||
tunSock.Close()
|
||||
unix.Close(t.fd)
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
uapiConn, err := t.uapi.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go t.device.IpcHandle(uapiConn)
|
||||
}
|
||||
}()
|
||||
|
||||
err = t.device.Up()
|
||||
if err != nil {
|
||||
tunSock.Close()
|
||||
t.device.Close()
|
||||
return err
|
||||
}
|
||||
@@ -100,13 +81,13 @@ func (t *tunDevice) UpdateAddr(addr WGAddress) error {
|
||||
}
|
||||
|
||||
func (t *tunDevice) Close() (err error) {
|
||||
if t.uapi != nil {
|
||||
err = t.uapi.Close()
|
||||
}
|
||||
|
||||
if t.device != nil {
|
||||
t.device.Close()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t *tunDevice) routesToString() string {
|
||||
return strings.Join(t.routes, ";")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (c *tunDevice) Create() error {
|
||||
if WireguardModuleIsLoaded() {
|
||||
if WireGuardModuleIsLoaded() {
|
||||
log.Info("using kernel WireGuard")
|
||||
return c.createWithKernel()
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (c *tunDevice) Create() error {
|
||||
|
||||
}
|
||||
|
||||
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.
|
||||
// createWithKernel Creates a new WireGuard interface using kernel WireGuard module.
|
||||
// Works for Linux and offers much better network performance
|
||||
func (c *tunDevice) createWithKernel() error {
|
||||
|
||||
|
||||
@@ -6,10 +6,13 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"github.com/pion/transport/v2"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
@@ -18,13 +21,18 @@ type tunDevice struct {
|
||||
address WGAddress
|
||||
mtu int
|
||||
netInterface NetInterface
|
||||
iceBind *bind.ICEBind
|
||||
uapi net.Listener
|
||||
close chan struct{}
|
||||
}
|
||||
|
||||
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
|
||||
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
|
||||
return &tunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
iceBind: bind.NewICEBind(transportNet),
|
||||
close: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,23 +50,38 @@ func (c *tunDevice) DeviceName() string {
|
||||
}
|
||||
|
||||
func (c *tunDevice) Close() error {
|
||||
if c.netInterface == nil {
|
||||
return nil
|
||||
|
||||
select {
|
||||
case c.close <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
err := c.netInterface.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
var err1, err2, err3 error
|
||||
if c.netInterface != nil {
|
||||
err1 = c.netInterface.Close()
|
||||
}
|
||||
|
||||
if c.uapi != nil {
|
||||
err2 = c.uapi.Close()
|
||||
}
|
||||
|
||||
sockPath := "/var/run/wireguard/" + c.name + ".sock"
|
||||
if _, statErr := os.Stat(sockPath); statErr == nil {
|
||||
statErr = os.Remove(sockPath)
|
||||
if statErr != nil {
|
||||
return statErr
|
||||
err3 = statErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
|
||||
return err3
|
||||
}
|
||||
|
||||
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
|
||||
@@ -69,26 +92,36 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
|
||||
}
|
||||
|
||||
// We need to create a wireguard-go device and listen to configuration requests
|
||||
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||
err = tunDevice.Up()
|
||||
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||
err = tunDev.Up()
|
||||
if err != nil {
|
||||
return tunIface, err
|
||||
_ = tunIface.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo: after this line in case of error close the tunSock
|
||||
uapi, err := c.getUAPI(c.name)
|
||||
c.uapi, err = c.getUAPI(c.name)
|
||||
if err != nil {
|
||||
return tunIface, err
|
||||
_ = tunIface.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
uapiConn, uapiErr := uapi.Accept()
|
||||
select {
|
||||
case <-c.close:
|
||||
log.Debugf("exit uapi.Accept()")
|
||||
return
|
||||
default:
|
||||
}
|
||||
uapiConn, uapiErr := c.uapi.Accept()
|
||||
if uapiErr != nil {
|
||||
log.Traceln("uapi Accept failed with error: ", uapiErr)
|
||||
continue
|
||||
}
|
||||
go tunDevice.IpcHandle(uapiConn)
|
||||
go func() {
|
||||
tunDev.IpcHandle(uapiConn)
|
||||
log.Debugf("exit tunDevice.IpcHandle")
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -4,24 +4,39 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/windows/driver"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
)
|
||||
|
||||
type tunDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
netInterface NetInterface
|
||||
iceBind *bind.ICEBind
|
||||
mtu int
|
||||
uapi net.Listener
|
||||
close chan struct{}
|
||||
}
|
||||
|
||||
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
|
||||
return &tunDevice{name: name, address: address}
|
||||
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
|
||||
return &tunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
iceBind: bind.NewICEBind(transportNet),
|
||||
close: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *tunDevice) Create() error {
|
||||
var err error
|
||||
c.netInterface, err = c.createAdapter()
|
||||
c.netInterface, err = c.createWithUserspace()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -29,6 +44,51 @@ func (c *tunDevice) Create() error {
|
||||
return c.assignAddr()
|
||||
}
|
||||
|
||||
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
|
||||
func (c *tunDevice) createWithUserspace() (NetInterface, error) {
|
||||
tunIface, err := tun.CreateTUN(c.name, c.mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We need to create a wireguard-go device and listen to configuration requests
|
||||
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||
err = tunDev.Up()
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.uapi, err = c.getUAPI(c.name)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.close:
|
||||
log.Debugf("exit uapi.Accept()")
|
||||
return
|
||||
default:
|
||||
}
|
||||
uapiConn, uapiErr := c.uapi.Accept()
|
||||
if uapiErr != nil {
|
||||
log.Traceln("uapi Accept failed with error: ", uapiErr)
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
tunDev.IpcHandle(uapiConn)
|
||||
log.Debugf("exit tunDevice.IpcHandle")
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debugln("UAPI listener started")
|
||||
return tunIface, nil
|
||||
}
|
||||
|
||||
func (c *tunDevice) UpdateAddr(address WGAddress) error {
|
||||
c.address = address
|
||||
return c.assignAddr()
|
||||
@@ -43,19 +103,33 @@ func (c *tunDevice) DeviceName() string {
|
||||
}
|
||||
|
||||
func (c *tunDevice) Close() error {
|
||||
if c.netInterface == nil {
|
||||
return nil
|
||||
select {
|
||||
case c.close <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
return c.netInterface.Close()
|
||||
var err1, err2 error
|
||||
if c.netInterface != nil {
|
||||
err1 = c.netInterface.Close()
|
||||
}
|
||||
|
||||
if c.uapi != nil {
|
||||
err2 = c.uapi.Close()
|
||||
}
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
return err2
|
||||
}
|
||||
|
||||
func (c *tunDevice) getInterfaceGUIDString() (string, error) {
|
||||
if c.netInterface == nil {
|
||||
return "", fmt.Errorf("interface has not been initialized yet")
|
||||
}
|
||||
windowsDevice := c.netInterface.(*driver.Adapter)
|
||||
luid := windowsDevice.LUID()
|
||||
windowsDevice := c.netInterface.(*tun.NativeTun)
|
||||
luid := winipcfg.LUID(windowsDevice.LUID())
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -63,31 +137,15 @@ func (c *tunDevice) getInterfaceGUIDString() (string, error) {
|
||||
return guid.String(), nil
|
||||
}
|
||||
|
||||
func (c *tunDevice) createAdapter() (NetInterface, error) {
|
||||
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
|
||||
adapter, err := driver.CreateAdapter(c.name, "WireGuard", &WintunStaticRequestedGUID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error creating adapter: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
err = adapter.SetAdapterState(driver.AdapterStateUp)
|
||||
if err != nil {
|
||||
return adapter, err
|
||||
}
|
||||
state, _ := adapter.LUID().GUID()
|
||||
log.Debugln("device guid: ", state.String())
|
||||
return adapter, nil
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (c *tunDevice) assignAddr() error {
|
||||
luid := c.netInterface.(*driver.Adapter).LUID()
|
||||
|
||||
tunDev := c.netInterface.(*tun.NativeTun)
|
||||
luid := winipcfg.LUID(tunDev.LUID())
|
||||
log.Debugf("adding address %s to interface: %s", c.address.IP, c.name)
|
||||
err := luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
|
||||
}
|
||||
|
||||
// getUAPI returns a Listener
|
||||
func (c *tunDevice) getUAPI(iface string) (net.Listener, error) {
|
||||
return ipc.UAPIListen(iface)
|
||||
}
|
||||
|
||||
@@ -172,6 +172,49 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRoutes return with the routes
|
||||
func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf("failed getting Management Service public key: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||
defer cancelStream()
|
||||
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.CloseSend()
|
||||
}()
|
||||
|
||||
update, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
log.Debugf("Management stream has been closed by server: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
log.Debugf("disconnected from Management Service sync stream: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decryptedResp := &proto.SyncResponse{}
|
||||
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if decryptedResp.GetNetworkMap() == nil {
|
||||
return nil, fmt.Errorf("invalid msg, required network map")
|
||||
}
|
||||
|
||||
return decryptedResp.GetNetworkMap().GetRoutes(), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||
req := &proto.SyncRequest{}
|
||||
|
||||
|
||||
@@ -49,15 +49,17 @@ type AccountManager interface {
|
||||
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
CreateUser(accountID, userID string, key *UserInfo) (*UserInfo, error)
|
||||
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(accountID, executingUserID string, targetUserID string) error
|
||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountByUserID(userID string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
MarkPATUsed(tokenID string) error
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
IsUserAdmin(userID string) (bool, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeerByKey(peerKey string) (*Peer, error)
|
||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||
@@ -171,12 +173,13 @@ type Account struct {
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
}
|
||||
|
||||
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
@@ -283,7 +286,7 @@ func (a *Account) GetGroup(groupID string) *Group {
|
||||
|
||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
|
||||
aclPeers, _ := a.getPeersByPolicy(peerID)
|
||||
aclPeers := a.getPeersByACL(peerID)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*Peer
|
||||
var expiredPeers []*Peer
|
||||
@@ -1228,9 +1231,11 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
if !user.IsServiceUser {
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
|
||||
@@ -87,6 +87,10 @@ const (
|
||||
PersonalAccessTokenCreated
|
||||
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
||||
PersonalAccessTokenDeleted
|
||||
// ServiceUserCreated indicates that a user created a service user
|
||||
ServiceUserCreated
|
||||
// ServiceUserDeleted indicates that a user deleted a service user
|
||||
ServiceUserDeleted
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -176,6 +180,10 @@ const (
|
||||
PersonalAccessTokenCreatedMessage string = "Personal access token created"
|
||||
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
|
||||
PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
|
||||
// ServiceUserCreatedMessage is a human-readable text message of the ServiceUserCreated activity
|
||||
ServiceUserCreatedMessage string = "Service user created"
|
||||
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
||||
ServiceUserDeletedMessage string = "Service user deleted"
|
||||
)
|
||||
|
||||
// Activity that triggered an Event
|
||||
@@ -270,6 +278,10 @@ func (a Activity) Message() string {
|
||||
return PersonalAccessTokenCreatedMessage
|
||||
case PersonalAccessTokenDeleted:
|
||||
return PersonalAccessTokenDeletedMessage
|
||||
case ServiceUserCreated:
|
||||
return ServiceUserCreatedMessage
|
||||
case ServiceUserDeleted:
|
||||
return ServiceUserDeletedMessage
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
@@ -364,6 +376,10 @@ func (a Activity) StringCode() string {
|
||||
return "personal.access.token.create"
|
||||
case PersonalAccessTokenDeleted:
|
||||
return "personal.access.token.delete"
|
||||
case ServiceUserCreated:
|
||||
return "service.user.create"
|
||||
case ServiceUserDeleted:
|
||||
return "service.user.delete"
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
|
||||
@@ -286,9 +286,7 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
|
||||
}
|
||||
|
||||
if accountCopy.Rules == nil {
|
||||
accountCopy.Rules = make(map[string]*Rule)
|
||||
}
|
||||
accountCopy.Rules = make(map[string]*Rule)
|
||||
for _, policy := range accountCopy.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
accountCopy.Rules[rule.ID] = rule.ToRule()
|
||||
|
||||
@@ -77,6 +77,10 @@ components:
|
||||
description: Is true if authenticated user is the same as this user
|
||||
type: boolean
|
||||
readOnly: true
|
||||
is_service_user:
|
||||
description: Is true if this user is a service user
|
||||
type: boolean
|
||||
readOnly: true
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
@@ -115,10 +119,13 @@ components:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
is_service_user:
|
||||
description: Is true if this user is a service user
|
||||
type: boolean
|
||||
required:
|
||||
- role
|
||||
- auto_groups
|
||||
- email
|
||||
- is_service_user
|
||||
PeerMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@@ -825,6 +832,12 @@ paths:
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: query
|
||||
name: service_user
|
||||
schema:
|
||||
type: boolean
|
||||
description: Filters users and returns either normal users or service users
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON array of Users
|
||||
@@ -903,6 +916,30 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a User
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The User ID
|
||||
responses:
|
||||
'200':
|
||||
description: Delete status code
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/{userId}/tokens:
|
||||
get:
|
||||
summary: Returns a list of all tokens for a user
|
||||
|
||||
@@ -677,6 +677,9 @@ type User struct {
|
||||
// IsCurrent Is true if authenticated user is the same as this user
|
||||
IsCurrent *bool `json:"is_current,omitempty"`
|
||||
|
||||
// IsServiceUser Is true if this user is a service user
|
||||
IsServiceUser *bool `json:"is_service_user,omitempty"`
|
||||
|
||||
// Name User's name from idp provider
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -696,7 +699,10 @@ type UserCreateRequest struct {
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Email User's Email to send invite to
|
||||
Email string `json:"email"`
|
||||
Email *string `json:"email,omitempty"`
|
||||
|
||||
// IsServiceUser Is true if this user is a service user
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
|
||||
// Name User's full name
|
||||
Name *string `json:"name,omitempty"`
|
||||
@@ -787,6 +793,12 @@ type PutApiRulesIdJSONBody struct {
|
||||
Sources *[]string `json:"sources,omitempty"`
|
||||
}
|
||||
|
||||
// GetApiUsersParams defines parameters for GetApiUsers.
|
||||
type GetApiUsersParams struct {
|
||||
// ServiceUser Filters users and returns either normal users or service users
|
||||
ServiceUser *bool `form:"service_user,omitempty" json:"service_user,omitempty"`
|
||||
}
|
||||
|
||||
// PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType.
|
||||
type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody
|
||||
|
||||
|
||||
@@ -111,6 +111,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
|
||||
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{id}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
type IsUserAdminFunc func(userID string) (bool, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
@@ -37,7 +37,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
ok, err := a.isUserAdmin(claims)
|
||||
ok, err := a.isUserAdmin(claims.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
|
||||
@@ -3,8 +3,10 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
@@ -77,6 +79,36 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// DeleteUser is a DELETE request to delete a user (only works for service users right now)
|
||||
func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["id"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -103,11 +135,17 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
email := ""
|
||||
if req.Email != nil {
|
||||
email = *req.Email
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
|
||||
Email: req.Email,
|
||||
Name: *req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
Email: email,
|
||||
Name: *req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
IsServiceUser: req.IsServiceUser,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@@ -137,9 +175,27 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
serviceUser := r.URL.Query().Get("service_user")
|
||||
|
||||
log.Debugf("UserCount: %v", len(data))
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
continue
|
||||
}
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.Debugf("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
}
|
||||
log.Debugf("User %v is service user: %v", r.Name, r.IsServiceUser)
|
||||
if includeServiceUser == r.IsServiceUser {
|
||||
log.Debugf("Found service user: %v", r.Name)
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, users)
|
||||
@@ -163,12 +219,13 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,52 +1,91 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
func initUsers(user ...*server.User) *UsersHandler {
|
||||
const (
|
||||
serviceUserID = "serviceUserID"
|
||||
regularUserID = "regularUserID"
|
||||
)
|
||||
|
||||
var usersTestAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: domain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func initUsersTestData() *UsersHandler {
|
||||
return &UsersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
users := make(map[string]*server.User, 0)
|
||||
for _, u := range user {
|
||||
users[u.Id] = u
|
||||
}
|
||||
return &server.Account{
|
||||
Id: "12345",
|
||||
Domain: "netbird.io",
|
||||
Users: users,
|
||||
}, users[claims.UserId], nil
|
||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
for _, v := range user {
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &server.UserInfo{
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
Email: "",
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
Email: "",
|
||||
IsServiceUser: v.IsServiceUser,
|
||||
})
|
||||
}
|
||||
return users, nil
|
||||
},
|
||||
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
if !usersTestAccount.Users[targetUserID].IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "user with ID %s is not a service user and can not be deleted", targetUserID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "1",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
UserId: existingUserID,
|
||||
Domain: domain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
@@ -54,8 +93,84 @@ func initUsers(user ...*server.User) *UsersHandler {
|
||||
}
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}}
|
||||
userHandler := initUsers(users...)
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
requestPath string
|
||||
expectedUserIDs []string
|
||||
}{
|
||||
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
|
||||
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
|
||||
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
|
||||
userHandler.GetAllUsers(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
respBody := []*server.UserInfo{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(respBody), len(tc.expectedUserIDs))
|
||||
for _, v := range respBody {
|
||||
assert.Contains(t, tc.expectedUserIDs, v.ID)
|
||||
assert.Equal(t, v.ID, usersTestAccount.Users[v.ID].Id)
|
||||
assert.Equal(t, v.Role, string(usersTestAccount.Users[v.ID].Role))
|
||||
assert.Equal(t, v.IsServiceUser, usersTestAccount.Users[v.ID].IsServiceUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
name := "name"
|
||||
email := "email"
|
||||
serviceUserToAdd := api.UserCreateRequest{
|
||||
AutoGroups: []string{},
|
||||
Email: nil,
|
||||
IsServiceUser: true,
|
||||
Name: &name,
|
||||
Role: "admin",
|
||||
}
|
||||
serviceUserString, err := json.Marshal(serviceUserToAdd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
regularUserToAdd := api.UserCreateRequest{
|
||||
AutoGroups: []string{},
|
||||
Email: &email,
|
||||
IsServiceUser: true,
|
||||
Name: &name,
|
||||
Role: "admin",
|
||||
}
|
||||
regularUserString, err := json.Marshal(regularUserToAdd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -65,40 +180,79 @@ func TestGetUsers(t *testing.T) {
|
||||
requestBody io.Reader
|
||||
expectedResult []*server.User
|
||||
}{
|
||||
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
|
||||
{name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
|
||||
// right now creation is blocked in AC middleware, will be refactored in the future
|
||||
{name: "CreateRegularUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(regularUserString)},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.GetAllUsers(rr, req)
|
||||
userHandler.CreateUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
respBody := []*server.UserInfo{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
if tc.expectedResult != nil {
|
||||
for i, resp := range respBody {
|
||||
assert.Equal(t, resp.ID, tc.expectedResult[i].Id)
|
||||
assert.Equal(t, string(resp.Role), string(tc.expectedResult[i].Role))
|
||||
}
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
requestVars map[string]string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "Delete Regular User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
requestVars: map[string]string{"id": regularUserID},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Delete Service User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + serviceUserID,
|
||||
requestVars: map[string]string{"id": serviceUserID},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete Not Existing User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + notFoundUserID,
|
||||
requestVars: map[string]string{"id": notFoundUserID},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.DeleteUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
GetAccountByUserIDFunc func(userID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
IsUserAdminFunc func(userID string) (bool, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||
@@ -61,6 +61,7 @@ type MockAccountManager struct {
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
@@ -112,12 +113,12 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
)
|
||||
}
|
||||
|
||||
// GetAccountByUser mock implementation of GetAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) {
|
||||
if am.GetAccountByUserFunc != nil {
|
||||
return am.GetAccountByUserFunc(userId)
|
||||
// GetAccountByUserID mock implementation of GetAccountByUserID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByUserID(userID string) (*server.Account, error) {
|
||||
if am.GetAccountByUserIDFunc != nil {
|
||||
return am.GetAccountByUserIDFunc(userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserID is not implemented")
|
||||
}
|
||||
|
||||
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
|
||||
@@ -394,9 +395,9 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
|
||||
}
|
||||
|
||||
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
|
||||
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
func (am *MockAccountManager) IsUserAdmin(userID string) (bool, error) {
|
||||
if am.IsUserAdminFunc != nil {
|
||||
return am.IsUserAdminFunc(claims)
|
||||
return am.IsUserAdminFunc(userID)
|
||||
}
|
||||
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
|
||||
}
|
||||
@@ -500,6 +501,14 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
|
||||
if am.DeleteUserFunc != nil {
|
||||
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
if am.GetNameServerGroupFunc != nil {
|
||||
|
||||
@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
// TODO: use firewall rules
|
||||
aclPeers, _ := account.getPeersByPolicy(peer.ID)
|
||||
aclPeers := account.getPeersByACL(peer.ID)
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@@ -816,7 +816,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*Pee
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.getPeersByPolicy(p.ID)
|
||||
aclPeers := account.getPeersByACL(p.ID)
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peerID {
|
||||
return peer, nil
|
||||
@@ -833,6 +833,98 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) *Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
// GetPeerRules returns a list of source or destination rules of a given peer.
|
||||
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
|
||||
// Rules are group based so there is no direct access to peers.
|
||||
// First, find all groups that the given peer belongs to
|
||||
peerGroups := make(map[string]struct{})
|
||||
|
||||
for s, group := range a.Groups {
|
||||
for _, peer := range group.Peers {
|
||||
if peerID == peer {
|
||||
peerGroups[s] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second, find all rules that have discovered source and destination groups
|
||||
srcRulesMap := make(map[string]*Rule)
|
||||
dstRulesMap := make(map[string]*Rule)
|
||||
for _, rule := range a.Rules {
|
||||
for _, g := range rule.Source {
|
||||
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
|
||||
srcRules = append(srcRules, rule)
|
||||
srcRulesMap[rule.ID] = rule
|
||||
}
|
||||
}
|
||||
for _, g := range rule.Destination {
|
||||
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
|
||||
dstRules = append(dstRules, rule)
|
||||
dstRulesMap[rule.ID] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return srcRules, dstRules
|
||||
}
|
||||
|
||||
// getPeersByACL returns all peers that given peer has access to.
|
||||
func (a *Account) getPeersByACL(peerID string) []*Peer {
|
||||
var peers []*Peer
|
||||
srcRules, dstRules := a.GetPeerRules(peerID)
|
||||
|
||||
groups := map[string]*Group{}
|
||||
for _, r := range srcRules {
|
||||
if r.Disabled {
|
||||
continue
|
||||
}
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Destination {
|
||||
if group, ok := a.Groups[gid]; ok {
|
||||
groups[gid] = group
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range dstRules {
|
||||
if r.Disabled {
|
||||
continue
|
||||
}
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Source {
|
||||
if group, ok := a.Groups[gid]; ok {
|
||||
groups[gid] = group
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersSet := make(map[string]struct{})
|
||||
for _, g := range groups {
|
||||
for _, pid := range g.Peers {
|
||||
peer, ok := a.Peers[pid]
|
||||
if !ok {
|
||||
log.Warnf(
|
||||
"peer %s found in group %s but doesn't belong to account %s",
|
||||
pid,
|
||||
g.ID,
|
||||
a.Id,
|
||||
)
|
||||
continue
|
||||
}
|
||||
// exclude original peer
|
||||
if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok {
|
||||
peersSet[peer.ID] = struct{}{}
|
||||
peers = append(peers, peer.Copy())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peers
|
||||
}
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
||||
|
||||
@@ -136,6 +136,8 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
// TODO: disable until we start use policy again
|
||||
t.Skip()
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
@@ -42,8 +42,11 @@ type UserRole string
|
||||
|
||||
// User represents a user of the system
|
||||
type User struct {
|
||||
Id string
|
||||
Role UserRole
|
||||
Id string
|
||||
Role UserRole
|
||||
IsServiceUser bool
|
||||
// ServiceUserName is only set if IsServiceUser is true
|
||||
ServiceUserName string
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string
|
||||
PATs map[string]*PersonalAccessToken
|
||||
@@ -63,12 +66,13 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
|
||||
if userData == nil {
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: "",
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
@@ -81,12 +85,13 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -101,34 +106,88 @@ func (u *User) Copy() *User {
|
||||
pats[k] = patCopy
|
||||
}
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
Role: u.Role,
|
||||
AutoGroups: autoGroups,
|
||||
PATs: pats,
|
||||
Id: u.Id,
|
||||
Role: u.Role,
|
||||
AutoGroups: autoGroups,
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole) *User {
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, serviceUserName string, autoGroups []string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
AutoGroups: []string{},
|
||||
Id: id,
|
||||
Role: role,
|
||||
IsServiceUser: isServiceUser,
|
||||
ServiceUserName: serviceUserName,
|
||||
AutoGroups: autoGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleAdmin
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser)
|
||||
return NewUser(id, UserRoleUser, false, "", []string{})
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin)
|
||||
return NewUser(id, UserRoleAdmin, false, "", []string{})
|
||||
}
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
if executingUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only admins can create service users")
|
||||
}
|
||||
|
||||
newUserID := uuid.New().String()
|
||||
newUser := NewUser(newUserID, role, true, serviceUserName, autoGroups)
|
||||
log.Debugf("New User: %v", newUser)
|
||||
account.Users[newUserID] = newUser
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": newUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||
|
||||
return &UserInfo{
|
||||
ID: newUser.Id,
|
||||
Email: "",
|
||||
Name: newUser.ServiceUserName,
|
||||
Role: string(newUser.Role),
|
||||
AutoGroups: newUser.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user under the given account. Effectively this is a user invite.
|
||||
func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *UserInfo) (*UserInfo, error) {
|
||||
if user.IsServiceUser {
|
||||
return am.createServiceUser(accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.AutoGroups)
|
||||
}
|
||||
return am.inviteNewUser(accountID, userID, user)
|
||||
}
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -193,8 +252,48 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us
|
||||
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from the given account.
|
||||
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin {
|
||||
return status.Errorf(status.PermissionDenied, "only admins can delete service users")
|
||||
}
|
||||
|
||||
if !targetUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "regular users can not be deleted")
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -206,21 +305,26 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
||||
}
|
||||
|
||||
if executingUserID != targetUserId {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser := account.Users[targetUserId]
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
||||
}
|
||||
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id)
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
||||
}
|
||||
@@ -232,8 +336,8 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name}
|
||||
am.storeEvent(executingUserID, targetUserId, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
|
||||
return pat, nil
|
||||
}
|
||||
@@ -243,21 +347,26 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
pat := targetUser.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
@@ -271,10 +380,10 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name}
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
|
||||
delete(user.PATs, tokenID)
|
||||
delete(targetUser.PATs, tokenID)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
@@ -288,21 +397,26 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
||||
}
|
||||
|
||||
pat := targetUser.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
@@ -315,22 +429,27 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
var pats []*PersonalAccessToken
|
||||
for _, pat := range user.PATs {
|
||||
for _, pat := range targetUser.PATs {
|
||||
pats = append(pats, pat)
|
||||
}
|
||||
|
||||
@@ -386,7 +505,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID})
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
@@ -397,14 +516,14 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID})
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
||||
userData, err := am.lookupUserInCache(newUser.Id, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -454,14 +573,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// IsUserAdmin flag for current user authenticated by JWT token
|
||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
// GetAccountByUserID returns an existing account for a given user id
|
||||
func (am *DefaultAccountManager) GetAccountByUserID(userID string) (*Account, error) {
|
||||
return am.Store.GetAccountByUser(userID)
|
||||
}
|
||||
|
||||
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
|
||||
func (am *DefaultAccountManager) IsUserAdmin(userID string) (bool, error) {
|
||||
account, err := am.GetAccountByUserID(userID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get account: %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
user, ok := account.Users[userID]
|
||||
if !ok {
|
||||
return false, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
@@ -486,7 +610,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
if !isNil(am.idpManager) {
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
for _, user := range account.Users {
|
||||
users[user.Id] = struct{}{}
|
||||
if !user.IsServiceUser {
|
||||
users[user.Id] = struct{}{}
|
||||
}
|
||||
}
|
||||
queriedUsers, err = am.lookupCache(users, accountID)
|
||||
if err != nil {
|
||||
@@ -512,20 +638,44 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
for _, queriedUser := range queriedUsers {
|
||||
if !user.IsAdmin() && user.Id != queriedUser.ID {
|
||||
for _, localUser := range account.Users {
|
||||
if !user.IsAdmin() && user.Id != localUser.Id {
|
||||
// if user is not an admin then show only current user and do not show other users
|
||||
continue
|
||||
}
|
||||
if localUser, contains := account.Users[queriedUser.ID]; contains {
|
||||
|
||||
info, err := localUser.toUserInfo(queriedUser)
|
||||
var info *UserInfo
|
||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||
info, err = localUser.toUserInfo(queriedUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userInfos = append(userInfos, info)
|
||||
} else {
|
||||
name := ""
|
||||
if localUser.IsServiceUser {
|
||||
name = localUser.ServiceUserName
|
||||
}
|
||||
info = &UserInfo{
|
||||
ID: localUser.Id,
|
||||
Email: "",
|
||||
Name: name,
|
||||
Role: string(localUser.Role),
|
||||
AutoGroups: localUser.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: localUser.IsServiceUser,
|
||||
}
|
||||
}
|
||||
userInfos = append(userInfos, info)
|
||||
}
|
||||
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
for _, user := range userData {
|
||||
if user.ID == userID {
|
||||
return user, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -1,25 +1,32 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
const (
|
||||
mockAccountID = "accountID"
|
||||
mockUserID = "userID"
|
||||
mockTargetUserId = "targetUserID"
|
||||
mockTokenID1 = "tokenID1"
|
||||
mockToken1 = "SoMeHaShEdToKeN1"
|
||||
mockTokenID2 = "tokenID2"
|
||||
mockToken2 = "SoMeHaShEdToKeN2"
|
||||
mockTokenName = "tokenName"
|
||||
mockEmptyTokenName = ""
|
||||
mockExpiresIn = 7
|
||||
mockWrongExpiresIn = 4506
|
||||
mockAccountID = "accountID"
|
||||
mockUserID = "userID"
|
||||
mockServiceUserID = "serviceUserID"
|
||||
mockRole = "user"
|
||||
mockServiceUserName = "serviceUserName"
|
||||
mockTargetUserId = "targetUserID"
|
||||
mockTokenID1 = "tokenID1"
|
||||
mockToken1 = "SoMeHaShEdToKeN1"
|
||||
mockTokenID2 = "tokenID2"
|
||||
mockToken2 = "SoMeHaShEdToKeN2"
|
||||
mockTokenName = "tokenName"
|
||||
mockEmptyTokenName = ""
|
||||
mockExpiresIn = 7
|
||||
mockWrongExpiresIn = 4506
|
||||
)
|
||||
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
@@ -41,6 +48,8 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
||||
|
||||
fileStore := am.Store.(*FileStore)
|
||||
tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken]
|
||||
|
||||
@@ -60,7 +69,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
@@ -75,6 +87,31 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
assert.Errorf(t, err, "Creating PAT for different user should thorw error")
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
@@ -207,3 +244,300 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
|
||||
assert.Equal(t, 2, len(pats))
|
||||
}
|
||||
|
||||
func TestUser_Copy(t *testing.T) {
|
||||
// this is an imaginary case which will never be in DB this way
|
||||
user := User{
|
||||
Id: "userId",
|
||||
Role: "role",
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "servicename",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"pat1": {
|
||||
ID: "pat1",
|
||||
Name: "First PAT",
|
||||
HashedToken: "SoMeHaShEdToKeN",
|
||||
ExpirationDate: time.Now().AddDate(0, 0, 7),
|
||||
CreatedBy: "userId",
|
||||
CreatedAt: time.Now(),
|
||||
LastUsed: time.Now(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := validateStruct(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Test needs update: dummy struct has not all fields set : %s", err)
|
||||
}
|
||||
|
||||
copiedUser := user.Copy()
|
||||
|
||||
assert.True(t, cmp.Equal(user, *copiedUser))
|
||||
}
|
||||
|
||||
// based on https://medium.com/@anajankow/fast-check-if-all-struct-fields-are-set-in-golang-bba1917213d2
|
||||
func validateStruct(s interface{}) (err error) {
|
||||
|
||||
structType := reflect.TypeOf(s)
|
||||
structVal := reflect.ValueOf(s)
|
||||
fieldNum := structVal.NumField()
|
||||
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
field := structVal.Field(i)
|
||||
fieldName := structType.Field(i).Name
|
||||
|
||||
isSet := field.IsValid() && !field.IsZero()
|
||||
|
||||
if !isSet {
|
||||
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func TestUser_CreateServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
user, err := am.createServiceUser(mockAccountID, mockUserID, mockRole, mockServiceUserName, []string{"group1", "group2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating service user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
|
||||
assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID])
|
||||
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
|
||||
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
|
||||
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
|
||||
assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs)
|
||||
|
||||
assert.Zero(t, user.Email)
|
||||
assert.True(t, user.IsServiceUser)
|
||||
assert.Equal(t, "active", user.Status)
|
||||
}
|
||||
|
||||
func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
user, err := am.CreateUser(mockAccountID, mockUserID, &UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: mockRole,
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating user: %s", err)
|
||||
}
|
||||
|
||||
assert.True(t, user.IsServiceUser)
|
||||
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
|
||||
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
|
||||
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
|
||||
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
|
||||
|
||||
assert.Equal(t, mockServiceUserName, user.Name)
|
||||
assert.Equal(t, mockRole, user.Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, user.AutoGroups)
|
||||
assert.Equal(t, "active", user.Status)
|
||||
}
|
||||
|
||||
func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
_, err = am.CreateUser(mockAccountID, mockUserID, &UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: mockRole,
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
})
|
||||
|
||||
assert.Errorf(t, err, "Not configured IDP will throw error but right path used")
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: mockServiceUserName,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
err = am.DeleteUser(mockAccountID, mockUserID, mockServiceUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when deleting user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users))
|
||||
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
|
||||
|
||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(mockUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(mockUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(mockAccountID, mockUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting users from account: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(users))
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(mockAccountID, mockServiceUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting users from account: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(users))
|
||||
assert.Equal(t, mockServiceUserID, users[0].ID)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user