Compare commits

...

9 Commits

Author SHA1 Message Date
pascal-fischer
6fec0c682e Merging full service user feature into main (#819)
Merging full feature branch into main.
Adding full support for service users including backend objects, persistence, verification and api endpoints.
2023-04-22 12:57:51 +02:00
Chinmay Pai
c2e90a2a97 feat: add support for custom device hostname (#789)
Configure via --hostname (or -n) flag in the `up` and `login` commands
---------

Signed-off-by: Chinmay D. Pai <chinmay.pai@zerodha.com>
2023-04-20 16:00:22 +02:00
Maycon Santos
118880b6f7 Send a status notification on offline peers change (#821)
Sum offline peers too
2023-04-20 15:59:07 +02:00
Zoltan Papp
bb147c2a7c Remove unnecessary uapi open (#807)
Remove unnecessary uapi open from Android implementation
2023-04-17 11:50:12 +02:00
Zoltan Papp
4616bc5258 Add route management for Android interface (#801)
Support client route management feature on Android
2023-04-17 11:15:37 +02:00
Zoltan Papp
1803cf3678 Fix error handling in case of the port is in used (#810) 2023-04-14 16:18:00 +02:00
Zoltan Papp
9f35a7fb8d Ignore ipv6 labeled address (#809)
Ignore ipv6 labeled address
2023-04-14 15:40:27 +02:00
Misha Bragin
2eeed55c18 Bind implementation (#779)
This PR adds supports for the WireGuard userspace implementation
using Bind interface from wireguard-go. 
The newly introduced ICEBind struct implements Bind with UDPMux-based
structs from pion/ice to handle hole punching using ICE.
The core implementation was taken from StdBind of wireguard-go.

The result is a single WireGuard port that is used for host and server reflexive candidates. 
Relay candidates are still handled separately and will be integrated in the following PRs.

ICEBind checks the incoming packets for being STUN or WireGuard ones
and routes them to UDPMux (to handle hole punching) or to WireGuard  respectively.
2023-04-13 17:00:01 +02:00
Givi Khojanashvili
0343c5f239 Rollback simple ACL rules processing. (#803) 2023-04-12 09:39:17 +02:00
99 changed files with 3540 additions and 955 deletions

View File

@@ -6,6 +6,10 @@ on:
- main - main
pull_request: pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
test: test:
runs-on: macos-latest runs-on: macos-latest

View File

@@ -6,6 +6,10 @@ on:
- main - main
pull_request: pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
test: test:
strategy: strategy:
@@ -66,7 +70,7 @@ jobs:
run: go mod tidy run: go mod tidy
- name: Generate Iface Test bin - name: Generate Iface Test bin
run: go test -c -o iface-testing.bin ./iface/... run: go test -c -o iface-testing.bin ./iface/
- name: Generate RouteManager Test bin - name: Generate RouteManager Test bin
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/... run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...

View File

@@ -6,47 +6,45 @@ on:
- main - main
pull_request: pull_request:
env:
downloadPath: '${{ github.workspace }}\temp'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
pre:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- run: bash -x wireguard_nt.sh
working-directory: client
- uses: actions/upload-artifact@v2
with:
name: syso
path: client/*.syso
retention-days: 1
test: test:
needs: pre
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
id: go
with: with:
go-version: 1.19.x go-version: 1.19.x
- uses: actions/cache@v2 - name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with: with:
path: | file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
%LocalAppData%\go-build file-name: wintun.zip
~\go\pkg\mod location: ${{ env.downloadPath }}
~\AppData\Local\go-build sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- uses: actions/download-artifact@v2 - name: Decompressing wintun files
with: run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
name: syso
path: iface\
- name: Test - run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
run: go test -tags=load_wgnt_from_rsrc -timeout 5m -p 1 ./...
- run: choco install -y sysinternals
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -1,5 +1,8 @@
name: golangci-lint name: golangci-lint
on: [pull_request] on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
golangci: golangci:
name: lint name: lint

View File

@@ -7,7 +7,9 @@ on:
pull_request: pull_request:
paths: paths:
- "release_files/install.sh" - "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
install-cli-only: install-cli-only:
runs-on: macos-latest runs-on: macos-latest

View File

@@ -7,7 +7,9 @@ on:
pull_request: pull_request:
paths: paths:
- "release_files/install.sh" - "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
install-cli-only: install-cli-only:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -9,9 +9,13 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.5" SIGN_PIPE_VER: "v0.0.6"
GORELEASER_VER: "v1.14.1" GORELEASER_VER: "v1.14.1"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -21,10 +25,6 @@ jobs:
uses: actions/checkout@v2 uses: actions/checkout@v2
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Generate syso with DLL
run: bash -x wireguard_nt.sh
working-directory: client
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
@@ -59,6 +59,17 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
- -
name: Run GoReleaser name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v2

View File

@@ -6,6 +6,10 @@ on:
- main - main
pull_request: pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -26,7 +26,7 @@ type TunAdapter interface {
// IFaceDiscover export internal IFaceDiscover for mobile // IFaceDiscover export internal IFaceDiscover for mobile
type IFaceDiscover interface { type IFaceDiscover interface {
stdnet.IFaceDiscover stdnet.ExternalIFaceDiscover
} }
func init() { func init() {

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
) )
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
@@ -32,6 +33,11 @@ var loginCmd = &cobra.Command{
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
if hostName != "" {
// nolint
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
// workaround to run without service // workaround to run without service
if logFile == "console" { if logFile == "console" {
err = handleRebrand(cmd) err = handleRebrand(cmd)

View File

@@ -45,6 +45,7 @@ var (
managementURL string managementURL string
adminURL string adminURL string
setupKey string setupKey string
hostName string
preSharedKey string preSharedKey string
natExternalIPs []string natExternalIPs []string
customDNSAddress string customDNSAddress string
@@ -94,6 +95,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.AddCommand(serviceCmd) rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd) rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd) rootCmd.AddCommand(downCmd)

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -55,6 +56,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
if hostName != "" {
// nolint
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
if foregroundMode { if foregroundMode {
return runInForegroundMode(ctx, cmd) return runInForegroundMode(ctx, cmd)
} }

View File

@@ -193,6 +193,7 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe`
Sleep 3000 Sleep 3000
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
SetShellVarContext current SetShellVarContext current

View File

@@ -27,7 +27,7 @@ const (
) )
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-"} "Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
// ConfigInput carries configuration changes to the client // ConfigInput carries configuration changes to the client
type ConfigInput struct { type ConfigInput struct {

View File

@@ -23,7 +23,7 @@ import (
) )
// RunClient with main logic. // RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) error { func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@@ -108,7 +108,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(), PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(), KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(), FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
@@ -144,13 +144,19 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover) engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, statusRecorder) md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
if err != nil {
log.Error(err)
return wrapErr(err)
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
@@ -194,13 +200,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
} }
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address, WgAddr: peerConfig.Address,
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
IFaceBlackList: config.IFaceBlackList, IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery, DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key, WgPrivateKey: key,

View File

@@ -9,7 +9,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/miekg/dns" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@@ -199,7 +202,11 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil) newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -20,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -47,10 +46,6 @@ var ErrResetConnection = fmt.Errorf("reset connection")
type EngineConfig struct { type EngineConfig struct {
WgPort int WgPort int
WgIfaceName string WgIfaceName string
// TunAdapter is option. It is necessary for mobile version.
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.IFaceDiscover
// WgAddr is a Wireguard local address (Netbird Network IP) // WgAddr is a Wireguard local address (Netbird Network IP)
WgAddr string WgAddr string
@@ -91,6 +86,8 @@ type Engine struct {
syncMsgMux *sync.Mutex syncMsgMux *sync.Mutex
config *EngineConfig config *EngineConfig
mobileDep MobileDependency
// STUNs is a list of STUN servers used by ICE // STUNs is a list of STUN servers used by ICE
STUNs []*ice.URL STUNs []*ice.URL
// TURNs is a list of STUN servers used by ICE // TURNs is a list of STUN servers used by ICE
@@ -130,7 +127,7 @@ type Peer struct {
func NewEngine( func NewEngine(
ctx context.Context, cancel context.CancelFunc, ctx context.Context, cancel context.CancelFunc,
signalClient signal.Client, mgmClient mgm.Client, signalClient signal.Client, mgmClient mgm.Client,
config *EngineConfig, statusRecorder *peer.Status, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
ctx: ctx, ctx: ctx,
@@ -140,6 +137,7 @@ func NewEngine(
peerConns: make(map[string]*peer.Conn), peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{},
config: config, config: config,
mobileDep: mobileDep,
STUNs: []*ice.URL{}, STUNs: []*ice.URL{},
TURNs: []*ice.URL{}, TURNs: []*ice.URL{},
networkSerial: 0, networkSerial: 0,
@@ -166,34 +164,56 @@ func (e *Engine) Stop() error {
return nil return nil
} }
// Start creates a new Wireguard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error { func (e *Engine) Start() error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
wgIfaceName := e.config.WgIfaceName wgIFaceName := e.config.WgIfaceName
wgAddr := e.config.WgAddr wgAddr := e.config.WgAddr
myPrivateKey := e.config.WgPrivateKey myPrivateKey := e.config.WgPrivateKey
var err error var err error
transportNet, err := e.newStdNet()
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter)
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error()) log.Errorf("failed to create pion's stdnet: %s", err)
}
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
return err return err
} }
err = e.wgInterface.Create()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
e.close()
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
e.close()
return err
}
if e.wgInterface.IsUserspaceBind() {
iceBind := e.wgInterface.GetBind()
udpMux, err := iceBind.GetICEMux()
if err != nil {
e.close()
return err
}
e.udpMux = udpMux.UDPMuxDefault
e.udpMuxSrflx = udpMux
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
} else {
networkName := "udp" networkName := "udp"
if e.config.DisableIPv6Discovery { if e.config.DisableIPv6Discovery {
networkName = "udp4" networkName = "udp4"
} }
transportNet, err := e.newStdNet()
if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err)
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort}) e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil { if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error()) log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
@@ -213,19 +233,6 @@ func (e *Engine) Start() error {
return err return err
} }
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet}) e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
err = e.wgInterface.Create()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
e.close()
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
e.close()
return err
} }
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
@@ -496,7 +503,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr, IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(), PubKey: e.config.WgPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(), KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: conf.GetFqdn(), FQDN: conf.GetFqdn(),
}) })
@@ -822,9 +829,10 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
ProxyConfig: proxyConfig, ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
} }
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover) peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1006,12 +1014,6 @@ func (e *Engine) close() {
} }
} }
if e.udpMuxSrflx != nil {
if err := e.udpMuxSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux: %v", err)
}
}
if e.udpMuxConn != nil { if e.udpMuxConn != nil {
if err := e.udpMuxConn.Close(); err != nil { if err := e.udpMuxConn.Close(); err != nil {
log.Debugf("close udp mux connection: %v", err) log.Debugf("close udp mux connection: %v", err)

View File

@@ -3,9 +3,9 @@
package internal package internal
import ( import (
"github.com/pion/transport/v2/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet() return stdnet.NewNet(e.config.IFaceBlackList)
} }

View File

@@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.config.IFaceDiscover) return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
} }

View File

@@ -3,6 +3,8 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2/stdnet"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@@ -72,7 +74,7 @@ func TestEngine_SSH(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -206,12 +208,24 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
if err != nil {
t.Fatal(err)
}
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }
conn, err := net.ListenUDP("udp4", nil)
if err != nil {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
type testCase struct { type testCase struct {
name string name string
@@ -390,7 +404,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -548,8 +562,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@@ -713,8 +731,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
@@ -978,7 +1000,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort, WgPort: wgPort,
} }
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
} }
func startSignal() (*grpc.Server, string, error) { func startSignal() (*grpc.Server, string, error) {

View File

@@ -0,0 +1,13 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
)
// MobileDependency collect all dependencies for mobile platform
type MobileDependency struct {
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
Routes []string
}

View File

@@ -0,0 +1,29 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: ifaceDiscover,
}
err := md.readMap(mgmClient)
return md, err
}
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
routes, err := mgmClient.GetRoutes()
if err != nil {
return err
}
d.Routes = make([]string, len(routes))
for i, r := range routes {
d.Routes[i] = r.GetNetwork()
}
return nil
}

View File

@@ -0,0 +1,13 @@
//go:build !android
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
return MobileDependency{}, nil
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -46,6 +45,9 @@ type ConnConfig struct {
LocalWgPort int LocalWgPort int
NATExternalIPs []string NATExternalIPs []string
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
} }
// OfferAnswer represents a session establishment offer or answer // OfferAnswer represents a session establishment offer or answer
@@ -95,7 +97,7 @@ type Conn struct {
meta meta meta meta
adapter iface.TunAdapter adapter iface.TunAdapter
iFaceDiscover stdnet.IFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
} }
// meta holds meta information about a connection // meta holds meta information about a connection
@@ -121,7 +123,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*Conn, error) { func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
return &Conn{ return &Conn{
config: config, config: config,
mu: sync.Mutex{}, mu: sync.Mutex{},
@@ -136,32 +138,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
}, nil }, nil
} }
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them
func interfaceFilter(blackList []string) func(string) bool {
return func(iFace string) bool {
for _, s := range blackList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}
func (conn *Conn) reCreateAgent() error { func (conn *Conn) reCreateAgent() error {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -171,7 +147,7 @@ func (conn *Conn) reCreateAgent() error {
var err error var err error
transportNet, err := conn.newStdNet() transportNet, err := conn.newStdNet()
if err != nil { if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }
agentConfig := &ice.AgentConfig{ agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled, MulticastDNSMode: ice.MulticastDNSModeDisabled,
@@ -179,7 +155,7 @@ func (conn *Conn) reCreateAgent() error {
Urls: conn.config.StunTurn, Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}, CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
FailedTimeout: &failedTimeout, FailedTimeout: &failedTimeout,
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList), InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux, UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx, UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs, NAT1To1IPs: conn.config.NATExternalIPs,
@@ -319,7 +295,7 @@ func (conn *Conn) Open() error {
return err return err
} }
if conn.proxy.Type() == proxy.TypeNoProxy { if conn.proxy.Type() == proxy.TypeDirectNoProxy {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String()) host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String()) rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection // direct Wireguard connection
@@ -341,29 +317,62 @@ func (conn *Conn) Open() error {
// useProxy determines whether a direct connection (without a go proxy) is possible // useProxy determines whether a direct connection (without a go proxy) is possible
// //
// There are 2 cases: // There are 3 cases:
// //
// * When neither candidate is from hard nat and one of the peers has a public IP // * When neither candidate is from hard nat and one of the peers has a public IP
// //
// * both peers are in the same private network // * both peers are in the same private network
// //
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
//
// Please note, that this check happens when peers were already able to ping each other using ICE layer. // Please note, that this check happens when peers were already able to ping each other using ICE layer.
func shouldUseProxy(pair *ice.CandidatePair) bool { func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
if !isRelayCandidate(pair.Local) && userspaceBind {
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
return false
}
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) { if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
return false return false
} }
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) { if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
return false return false
} }
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) { if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
return false
}
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
return false return false
} }
return true return true
} }
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
localIP := net.ParseIP(pair.Local.Address())
remoteIP := net.ParseIP(pair.Remote.Address())
if localIP == nil || remoteIP == nil {
return false
}
// only consider /16 networks
mask := net.IPMask{255, 255, 0, 0}
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isHardNATCandidate(candidate ice.Candidate) bool { func isHardNATCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
} }
@@ -376,9 +385,13 @@ func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address()) return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
} }
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
}
func isPublicIP(address string) bool { func isPublicIP(address string) bool {
ip := net.ParseIP(address) ip := net.ParseIP(address)
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() { if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
return false return false
} }
return true return true
@@ -412,7 +425,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
} }
peerState.Direct = p.Type() == proxy.TypeNoProxy peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
@@ -423,8 +436,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
} }
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy { func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
useProxy := shouldUseProxy(pair)
localDirectMode := !useProxy localDirectMode := !useProxy
remoteDirectMode := localDirectMode remoteDirectMode := localDirectMode
@@ -434,13 +446,16 @@ func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgP
remoteDirectMode = conn.receiveRemoteDirectMode() remoteDirectMode = conn.receiveRemoteDirectMode()
} }
if conn.config.UserspaceBind && localDirectMode {
return proxy.NewNoProxy(conn.config.ProxyConfig)
}
if localDirectMode && remoteDirectMode { if localDirectMode && remoteDirectMode {
log.Debugf("using WireGuard direct mode with peer %s", conn.config.Key) return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
return proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
} }
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key) log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
return proxy.NewWireguardProxy(conn.config.ProxyConfig) return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
} }
func (conn *Conn) sendLocalDirectMode(localMode bool) { func (conn *Conn) sendLocalDirectMode(localMode bool) {

View File

@@ -5,6 +5,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@@ -28,7 +30,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts", ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"} "Tailscale", "tailscale"}
filter := interfaceFilter(ignore) filter := stdnet.InterfaceFilter(ignore)
for _, s := range ignore { for _, s := range ignore {
assert.Equal(t, filter(s), false) assert.Equal(t, filter(s), false)
@@ -208,6 +210,7 @@ func TestConn_ShouldUseProxy(t *testing.T) {
return ice.CandidateTypeHost return ice.CandidateTypeHost
}, },
} }
srflxCandidate := &mockICECandidate{ srflxCandidate := &mockICECandidate{
AddressFunc: func() string { AddressFunc: func() string {
return "1.1.1.1" return "1.1.1.1"
@@ -320,11 +323,47 @@ func TestConn_ShouldUseProxy(t *testing.T) {
}, },
expected: false, expected: false,
}, },
{
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: false,
},
{
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
candatePair: &ice.CandidatePair{
Local: &mockICECandidate{AddressFunc: func() string {
return "10.16.102.168"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
Remote: &mockICECandidate{AddressFunc: func() string {
return "10.16.101.96"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypePeerReflexive
}},
},
expected: true,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
result := shouldUseProxy(testCase.candatePair) result := shouldUseProxy(testCase.candatePair, false)
if result != testCase.expected { if result != testCase.expected {
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result) t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
} }
@@ -365,7 +404,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: true, inputDirectModeSupport: true,
inputRemoteModeMessage: true, inputRemoteModeMessage: true,
expected: proxy.TypeWireguard, expected: proxy.TypeWireGuard,
}, },
{ {
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy", name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
@@ -375,7 +414,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: true, inputDirectModeSupport: true,
inputRemoteModeMessage: false, inputRemoteModeMessage: false,
expected: proxy.TypeWireguard, expected: proxy.TypeWireGuard,
}, },
{ {
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy", name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
@@ -385,7 +424,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: false, inputDirectModeSupport: false,
inputRemoteModeMessage: false, inputRemoteModeMessage: false,
expected: proxy.TypeWireguard, expected: proxy.TypeWireGuard,
}, },
{ {
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy", name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
@@ -395,7 +434,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: false, inputDirectModeSupport: false,
inputRemoteModeMessage: false, inputRemoteModeMessage: false,
expected: proxy.TypeNoProxy, expected: proxy.TypeDirectNoProxy,
}, },
{ {
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy", name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
@@ -405,7 +444,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
}, },
inputDirectModeSupport: true, inputDirectModeSupport: true,
inputRemoteModeMessage: true, inputRemoteModeMessage: true,
expected: proxy.TypeNoProxy, expected: proxy.TypeDirectNoProxy,
}, },
} }
for _, testCase := range testCases { for _, testCase := range testCases {

View File

@@ -78,6 +78,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
defer d.mux.Unlock() defer d.mux.Unlock()
d.offlinePeers = make([]State, len(replacement)) d.offlinePeers = make([]State, len(replacement))
copy(d.offlinePeers, replacement) copy(d.offlinePeers, replacement)
d.notifyPeerListChanged()
} }
// AddPeer adds peer to Daemon status map // AddPeer adds peer to Daemon status map
@@ -308,7 +309,7 @@ func (d *Status) onConnectionChanged() {
} }
func (d *Status) notifyPeerListChanged() { func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(len(d.peers)) d.notifier.peerListChanged(len(d.peers) + len(d.offlinePeers))
} }
func (d *Status) notifyAddressChanged() { func (d *Status) notifyAddressChanged() {

View File

@@ -3,9 +3,9 @@
package peer package peer
import ( import (
"github.com/pion/transport/v2/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
func (conn *Conn) newStdNet() (*stdnet.Net, error) { func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet() return stdnet.NewNet(conn.config.InterfaceBlackList)
} }

View File

@@ -3,5 +3,5 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (conn *Conn) newStdNet() (*stdnet.Net, error) { func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(conn.iFaceDiscover) return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
} }

View File

@@ -0,0 +1,57 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
// This is possible in either of these cases:
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
type DirectNoProxy struct {
config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
}
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
}
// Close removes peer from the WireGuard interface
func (p *DirectNoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote IP and default WireGuard port
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
// Type returns the type of this proxy
func (p *DirectNoProxy) Type() Type {
return TypeDirectNoProxy
}

View File

@@ -5,24 +5,18 @@ import (
"net" "net"
) )
// NoProxy is used when there is no need for a proxy between ICE and Wireguard. // NoProxy is used just to configure WireGuard without any local proxy in between.
// This is possible in either of these cases: // Used when the WireGuard interface is userspace and uses bind.ICEBind
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// NoProxy will just update remote peer with a remote host and fixed Wireguard port (r.g. 51820).
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
type NoProxy struct { type NoProxy struct {
config Config config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
} }
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port // NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config, remoteWgPort int) *NoProxy { func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort} return &NoProxy{config: config}
} }
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error { func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey) err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil { if err != nil {
@@ -31,23 +25,16 @@ func (p *NoProxy) Close() error {
return nil return nil
} }
// Start just updates Wireguard peer with the remote IP and default Wireguard port // Start just updates WireGuard peer with the remote address
func (p *NoProxy) Start(remoteConn net.Conn) error { func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy while connecting to peer %s", p.config.RemoteKey) log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String()) addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil { if err != nil {
return err return err
} }
addr.Port = p.RemoteWgListenPort return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey) addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
} }
func (p *NoProxy) Type() Type { func (p *NoProxy) Type() Type {

View File

@@ -13,9 +13,10 @@ const DefaultWgKeepAlive = 25 * time.Second
type Type string type Type string
const ( const (
TypeNoProxy Type = "NoProxy" TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireguard Type = "Wireguard" TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy" TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
) )
type Config struct { type Config struct {

View File

@@ -6,8 +6,8 @@ import (
"net" "net"
) )
// WireguardProxy proxies // WireGuardProxy proxies
type WireguardProxy struct { type WireGuardProxy struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@@ -17,13 +17,13 @@ type WireguardProxy struct {
localConn net.Conn localConn net.Conn
} }
func NewWireguardProxy(config Config) *WireguardProxy { func NewWireGuardProxy(config Config) *WireGuardProxy {
p := &WireguardProxy{config: config} p := &WireGuardProxy{config: config}
p.ctx, p.cancel = context.WithCancel(context.Background()) p.ctx, p.cancel = context.WithCancel(context.Background())
return p return p
} }
func (p *WireguardProxy) updateEndpoint() error { func (p *WireGuardProxy) updateEndpoint() error {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
if err != nil { if err != nil {
return err return err
@@ -38,7 +38,7 @@ func (p *WireguardProxy) updateEndpoint() error {
return nil return nil
} }
func (p *WireguardProxy) Start(remoteConn net.Conn) error { func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn p.remoteConn = remoteConn
var err error var err error
@@ -60,7 +60,7 @@ func (p *WireguardProxy) Start(remoteConn net.Conn) error {
return nil return nil
} }
func (p *WireguardProxy) Close() error { func (p *WireGuardProxy) Close() error {
p.cancel() p.cancel()
if c := p.localConn; c != nil { if c := p.localConn; c != nil {
err := p.localConn.Close() err := p.localConn.Close()
@@ -77,7 +77,7 @@ func (p *WireguardProxy) Close() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer // proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks // blocks
func (p *WireguardProxy) proxyToRemote() { func (p *WireGuardProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
@@ -101,7 +101,7 @@ func (p *WireguardProxy) proxyToRemote() {
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard // proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks // blocks
func (p *WireguardProxy) proxyToLocal() { func (p *WireGuardProxy) proxyToLocal() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
@@ -123,6 +123,6 @@ func (p *WireguardProxy) proxyToLocal() {
} }
} }
func (p *WireguardProxy) Type() Type { func (p *WireGuardProxy) Type() Type {
return TypeWireguard return TypeWireGuard
} }

View File

@@ -1,12 +1,15 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
import "github.com/google/nftables"
const ( const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding" ipv6Forwarding = "netbird-rt-ipv6-forwarding"

View File

@@ -1,14 +1,17 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {

View File

@@ -1,10 +1,13 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
) )
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@@ -1,9 +1,130 @@
package routemanager package routemanager
import "github.com/netbirdio/netbird/route" import (
"context"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
Stop() Stop()
} }
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRouter: newServerRouter(ctx, wgInterface),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.cleanUp()
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}

View File

@@ -1,31 +0,0 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// DefaultManager dummy router manager for Android
type DefaultManager struct {
ctx context.Context
serverRouter *serverRouter
wgInterface *iface.WGIface
}
// NewManager returns a new dummy route manager what doing nothing
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
return &DefaultManager{}
}
// UpdateRoutes ...
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
return nil
}
// Stop ...
func (m *DefaultManager) Stop() {
}

View File

@@ -1,186 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRoutes map[string]*route.Route
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRoutes: make(map[string]*route.Route),
serverRouter: &serverRouter{
routes: make(map[string]*route.Route),
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
firewall: NewFirewall(ctx),
},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.firewall.CleanRoutingRules()
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.serverRouter.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.serverRoutes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
continue
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.serverRoutes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.serverRoutes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.serverRoutes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.serverRoutes[id] = newRoute
}
if len(m.serverRoutes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.updateServerRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}

View File

@@ -3,6 +3,7 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip" "net/netip"
"runtime" "runtime"
"testing" "testing"
@@ -391,7 +392,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@@ -414,7 +420,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes { if testCase.shouldCheckServerRoutes {
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match") require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
} }
}) })
} }

View File

@@ -1,16 +1,19 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
) )
import "github.com/google/nftables"
const ( const (
nftablesTable = "netbird-rt" nftablesTable = "netbird-rt"

View File

@@ -1,12 +1,15 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"testing"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
) )
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@@ -0,0 +1,24 @@
package routemanager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}

View File

@@ -1,67 +0,0 @@
package routemanager
import (
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"net/netip"
"sync"
)
type serverRouter struct {
routes map[string]*route.Route
// best effort to keep net forward configuration as it was
netForwardHistoryEnabled bool
mux sync.Mutex
firewall firewallManager
}
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.serverRouter.routes, route.ID)
return nil
}
}
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.serverRouter.routes[route.ID] = route
return nil
}
}

View File

@@ -0,0 +1,21 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{}
}
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (r *serverRouter) cleanUp() {}

View File

@@ -0,0 +1,120 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewallManager
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: NewFirewall(ctx),
wgInterface: wgInterface,
}
}
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.routes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.routes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.routes[id] = newRoute
}
if len(m.routes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.routes, route.ID)
return nil
}
}
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.routes[route.ID] = route
return nil
}
}
func (m *serverRouter) cleanUp() {
m.firewall.CleanRoutingRules()
}

View File

@@ -0,0 +1,13 @@
package routemanager
import (
"net/netip"
)
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return nil
}

View File

@@ -1,10 +1,13 @@
//go:build !android
package routemanager package routemanager
import ( import (
"github.com/vishvananda/netlink"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"github.com/vishvananda/netlink"
) )
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
@@ -62,12 +65,3 @@ func enableIPForwarding() error {
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
return err return err
} }
func isNetForwardHistoryEnabled() bool {
out, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
// todo
panic(err)
}
return string(out) == "1"
}

View File

@@ -1,11 +1,14 @@
//go:build !android
package routemanager package routemanager
import ( import (
"fmt" "fmt"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"net" "net"
"net/netip" "net/netip"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
) )
var errRouteNotFound = fmt.Errorf("route not found") var errRouteNotFound = fmt.Errorf("route not found")

View File

@@ -3,6 +3,7 @@ package routemanager
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net" "net"
"net/netip" "net/netip"
@@ -32,7 +33,11 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()

View File

@@ -4,10 +4,11 @@
package routemanager package routemanager
import ( import (
log "github.com/sirupsen/logrus"
"net/netip" "net/netip"
"os/exec" "os/exec"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
) )
func addToRouteTable(prefix netip.Prefix, addr string) error { func addToRouteTable(prefix netip.Prefix, addr string) error {
@@ -34,8 +35,3 @@ func enableIPForwarding() error {
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil
} }
func isNetForwardHistoryEnabled() bool {
log.Infof("check netforward history is not implemented on %s", runtime.GOOS)
return false
}

View File

@@ -0,0 +1,14 @@
package stdnet
import "github.com/pion/transport/v2"
// ExternalIFaceDiscover provide an option for external services (mobile)
// to collect network interface information
type ExternalIFaceDiscover interface {
// IFaces return with the description of the interfaces
IFaces() (string, error)
}
type iFaceDiscover interface {
iFaces() ([]*transport.Interface, error)
}

View File

@@ -0,0 +1,98 @@
package stdnet
import (
"fmt"
"net"
"strings"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
)
type mobileIFaceDiscover struct {
externalDiscover ExternalIFaceDiscover
}
func newMobileIFaceDiscover(externalDiscover ExternalIFaceDiscover) *mobileIFaceDiscover {
return &mobileIFaceDiscover{
externalDiscover: externalDiscover,
}
}
func (m *mobileIFaceDiscover) iFaces() ([]*transport.Interface, error) {
ifacesString, err := m.externalDiscover.IFaces()
if err != nil {
return nil, err
}
interfaces := m.parseInterfacesString(ifacesString)
return interfaces, nil
}
func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transport.Interface {
ifs := []*transport.Interface{}
for _, iface := range strings.Split(interfaces, "\n") {
if strings.TrimSpace(iface) == "" {
continue
}
fields := strings.Split(iface, "|")
if len(fields) != 2 {
log.Warnf("parseInterfacesString: unable to split %q", iface)
continue
}
var name string
var index, mtu int
var up, broadcast, loopback, pointToPoint, multicast bool
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
if err != nil {
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
continue
}
newIf := net.Interface{
Name: name,
Index: index,
MTU: mtu,
}
if up {
newIf.Flags |= net.FlagUp
}
if broadcast {
newIf.Flags |= net.FlagBroadcast
}
if loopback {
newIf.Flags |= net.FlagLoopback
}
if pointToPoint {
newIf.Flags |= net.FlagPointToPoint
}
if multicast {
newIf.Flags |= net.FlagMulticast
}
ifc := transport.NewInterface(newIf)
addrs := strings.Trim(fields[1], " \n")
foundAddress := false
for _, addr := range strings.Split(addrs, " ") {
if strings.Contains(addr, "%") {
continue
}
ip, ipNet, err := net.ParseCIDR(addr)
if err != nil {
log.Warnf("%s", err)
continue
}
ipNet.IP = ip
ifc.AddAddress(ipNet)
foundAddress = true
}
if foundAddress {
ifs = append(ifs, ifc)
}
}
return ifs
}

View File

@@ -3,6 +3,8 @@ package stdnet
import ( import (
"fmt" "fmt"
"testing" "testing"
log "github.com/sirupsen/logrus"
) )
func Test_parseInterfacesString(t *testing.T) { func Test_parseInterfacesString(t *testing.T) {
@@ -20,6 +22,7 @@ func Test_parseInterfacesString(t *testing.T) {
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"}, {"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"}, {"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"}, {"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
{"rmnet_data2", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2%rmnet2/64"},
} }
var exampleString string var exampleString string
@@ -35,11 +38,13 @@ func Test_parseInterfacesString(t *testing.T) {
d.multicast, d.multicast,
d.addr) d.addr)
} }
nets := parseInterfacesString(exampleString) d := mobileIFaceDiscover{}
nets := d.parseInterfacesString(exampleString)
if len(nets) == 0 { if len(nets) == 0 {
t.Fatalf("failed to parse interfaces") t.Fatalf("failed to parse interfaces")
} }
log.Printf("%d", len(nets))
for i, net := range nets { for i, net := range nets {
if net.MTU != testData[i].mtu { if net.MTU != testData[i].mtu {
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu) t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
@@ -58,7 +63,7 @@ func Test_parseInterfacesString(t *testing.T) {
if len(addr) == 0 { if len(addr) == 0 {
t.Errorf("invalid address parsing") t.Errorf("invalid address parsing")
} }
log.Printf("%v", addr)
if addr[0].String() != testData[i].addr { if addr[0].String() != testData[i].addr {
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr) t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
} }

View File

@@ -0,0 +1,36 @@
package stdnet
import (
"net"
"github.com/pion/transport/v2"
)
type pionDiscover struct {
}
func (d pionDiscover) iFaces() ([]*transport.Interface, error) {
ifs := []*transport.Interface{}
oifs, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, oif := range oifs {
ifc := transport.NewInterface(oif)
addrs, err := oif.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ifc.AddAddress(addr)
}
ifs = append(ifs, ifc)
}
return ifs, nil
}

View File

@@ -0,0 +1,40 @@
package stdnet
import (
"strings"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
)
// InterfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them.
func InterfaceFilter(disallowList []string) func(string) bool {
return func(iFace string) bool {
if strings.HasPrefix(iFace, "lo") {
// hardcoded loopback check to support already installed agents
return false
}
for _, s := range disallowList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}

View File

@@ -1,8 +0,0 @@
package stdnet
// IFaceDiscover provide an option for external services (mobile)
// to collect network interface information
type IFaceDiscover interface {
// IFaces return with the description of the interfaces
IFaces() (string, error)
}

View File

@@ -5,12 +5,9 @@ package stdnet
import ( import (
"fmt" "fmt"
"net"
"strings"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
"github.com/pion/transport/v2/stdnet" "github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
) )
// Net is an implementation of the net.Net interface // Net is an implementation of the net.Net interface
@@ -18,24 +15,40 @@ import (
type Net struct { type Net struct {
stdnet.Net stdnet.Net
interfaces []*transport.Interface interfaces []*transport.Interface
iFaceDiscover iFaceDiscover
// interfaceFilter should return true if the given interfaceName is allowed
interfaceFilter func(interfaceName string) bool
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
n := &Net{
iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover),
interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
} }
// NewNet creates a new StdNet instance. // NewNet creates a new StdNet instance.
func NewNet(iFaceDiscover IFaceDiscover) (*Net, error) { func NewNet(disallowList []string) (*Net, error) {
n := &Net{} n := &Net{
iFaceDiscover: pionDiscover{},
return n, n.UpdateInterfaces(iFaceDiscover) interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
} }
// UpdateInterfaces updates the internal list of network interfaces // UpdateInterfaces updates the internal list of network interfaces
// and associated addresses. // and associated addresses filtering them by name.
func (n *Net) UpdateInterfaces(iFaceDiscover IFaceDiscover) error { // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
ifacesString, err := iFaceDiscover.IFaces() // wasn't specified.
func (n *Net) UpdateInterfaces() (err error) {
allIfaces, err := n.iFaceDiscover.iFaces()
if err != nil { if err != nil {
return err return err
} }
n.interfaces = parseInterfacesString(ifacesString) n.interfaces = n.filterInterfaces(allIfaces)
return err return nil
} }
// Interfaces returns a slice of interfaces which are available on the // Interfaces returns a slice of interfaces which are available on the
@@ -70,68 +83,15 @@ func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name) return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
} }
func parseInterfacesString(interfaces string) []*transport.Interface { func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.Interface {
ifs := []*transport.Interface{} if n.interfaceFilter == nil {
return interfaces
for _, iface := range strings.Split(interfaces, "\n") {
if strings.TrimSpace(iface) == "" {
continue
} }
result := []*transport.Interface{}
fields := strings.Split(iface, "|") for _, iface := range interfaces {
if len(fields) != 2 { if n.interfaceFilter(iface.Name) {
log.Warnf("parseInterfacesString: unable to split %q", iface) result = append(result, iface)
continue
}
var name string
var index, mtu int
var up, broadcast, loopback, pointToPoint, multicast bool
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
if err != nil {
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
continue
}
newIf := net.Interface{
Name: name,
Index: index,
MTU: mtu,
}
if up {
newIf.Flags |= net.FlagUp
}
if broadcast {
newIf.Flags |= net.FlagBroadcast
}
if loopback {
newIf.Flags |= net.FlagLoopback
}
if pointToPoint {
newIf.Flags |= net.FlagPointToPoint
}
if multicast {
newIf.Flags |= net.FlagMulticast
}
ifc := transport.NewInterface(newIf)
addrs := strings.Trim(fields[1], " \n")
foundAddress := false
for _, addr := range strings.Split(addrs, " ") {
ip, ipNet, err := net.ParseCIDR(addr)
if err != nil {
log.Warnf("%s", err)
continue
}
ipNet.IP = ip
ifc.AddAddress(ipNet)
foundAddress = true
}
if foundAddress {
ifs = append(ifs, ifc)
} }
} }
return ifs return result
} }

View File

@@ -6,4 +6,4 @@
#define EXPAND(x) STRINGIZE(x) #define EXPAND(x) STRINGIZE(x)
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
7 ICON ui/netbird.ico 7 ICON ui/netbird.ico
wireguard.dll RCDATA wireguard.dll wintun.dll RCDATA wintun.dll

View File

@@ -43,6 +43,15 @@ func extractUserAgent(ctx context.Context) string {
return "" return ""
} }
// extractDeviceName extracts device name from context or returns the default system name
func extractDeviceName(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(DeviceNameCtxKey).(string)
if !ok {
return defaultName
}
return v
}
// GetDesktopUIUserAgent returns the Desktop ui user agent // GetDesktopUIUserAgent returns the Desktop ui user agent
func GetDesktopUIUserAgent() string { func GetDesktopUIUserAgent() string {
return "netbird-desktop-ui/" + version.NetbirdVersion() return "netbird-desktop-ui/" + version.NetbirdVersion()

View File

@@ -24,21 +24,13 @@ func GetInfo(ctx context.Context) *Info {
} }
gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname = extractDeviceName(ctx) gio.Hostname = extractDeviceName(ctx, "android")
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)
return gio return gio
} }
func extractDeviceName(ctx context.Context) string {
v, ok := ctx.Value(DeviceNameCtxKey).(string)
if !ok {
return "android"
}
return v
}
func uname() []string { func uname() []string {
res := run("/system/bin/uname", "-a") res := run("/system/bin/uname", "-a")
return strings.Split(res, " ") return strings.Split(res, " ")

View File

@@ -32,7 +32,8 @@ func GetInfo(ctx context.Context) *Info {
swVersion = []byte(release) swVersion = []byte(release)
} }
gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(swVersion)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(swVersion)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -24,7 +24,8 @@ func GetInfo(ctx context.Context) *Info {
osStr = strings.Replace(osStr, "\r\n", "", -1) osStr = strings.Replace(osStr, "\r\n", "", -1)
osInfo := strings.Split(osStr, " ") osInfo := strings.Split(osStr, " ")
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -50,7 +50,8 @@ func GetInfo(ctx context.Context) *Info {
osName = osInfo[3] osName = osInfo[3]
} }
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -24,3 +24,12 @@ func Test_UIVersion(t *testing.T) {
got := GetInfo(ctx) got := GetInfo(ctx)
assert.Equal(t, want, got.UIVersion) assert.Equal(t, want, got.UIVersion)
} }
func Test_CustomHostname(t *testing.T) {
// nolint
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
want := "custom-host"
got := GetInfo(ctx)
assert.Equal(t, want, got.Hostname)
}

View File

@@ -16,7 +16,8 @@ import (
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
ver := getOSVersion() ver := getOSVersion()
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -1,27 +0,0 @@
#!/bin/bash
ldir=$PWD
tmp_dir_path=$ldir/.distfiles
winnt=wireguard-nt.zip
download_file_path=$tmp_dir_path/$winnt
download_url=https://download.wireguard.com/wireguard-nt/wireguard-nt-0.10.1.zip
download_sha=772c0b1463d8d2212716f43f06f4594d880dea4f735165bd68e388fc41b81605
function resources_windows(){
cmd=$1
arch=$2
out=$3
docker run -i --rm -v $PWD:$PWD -w $PWD mstorsjo/llvm-mingw:latest $cmd -O coff -c 65001 -I $tmp_dir_path/wireguard-nt/bin/$arch -i resources.rc -o $out
}
mkdir -p $tmp_dir_path
curl -L#o $download_file_path.unverified $download_url
echo "$download_sha $download_file_path.unverified" | sha256sum -c
mv $download_file_path.unverified $download_file_path
mkdir -p .deps
unzip $download_file_path -d $tmp_dir_path
resources_windows i686-w64-mingw32-windres x86 resources_windows_386.syso
resources_windows aarch64-w64-mingw32-windres arm64 resources_windows_arm64.syso
resources_windows x86_64-w64-mingw32-windres amd64 resources_windows_amd64.syso

7
go.mod
View File

@@ -19,7 +19,7 @@ require (
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.7.0 golang.org/x/crypto v0.7.0
golang.org/x/sys v0.6.0 golang.org/x/sys v0.6.0
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1 golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.52.3 google.golang.org/grpc v1.52.3
@@ -48,6 +48,8 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/open-policy-agent/opa v0.49.0 github.com/open-policy-agent/opa v0.49.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
github.com/pion/stun v0.4.0
github.com/pion/transport/v2 v2.0.2 github.com/pion/transport/v2 v2.0.2
github.com/prometheus/client_golang v1.14.0 github.com/prometheus/client_golang v1.14.0
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
@@ -103,10 +105,8 @@ require (
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.6 // indirect github.com/pion/dtls/v2 v2.2.6 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.7 // indirect github.com/pion/mdns v0.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect github.com/pion/randutil v0.1.0 // indirect
github.com/pion/stun v0.4.0 // indirect
github.com/pion/turn/v2 v2.1.0 // indirect github.com/pion/turn/v2 v2.1.0 // indirect
github.com/pion/udp/v2 v2.0.1 // indirect github.com/pion/udp/v2 v2.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -131,7 +131,6 @@ require (
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/text v0.8.0 // indirect golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d // indirect
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect

5
go.sum
View File

@@ -881,13 +881,12 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI= golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 h1:3zl8RkJNQ8wfPRomwv/6DBbH2Ut6dgMaWTxM0ZunWnE= golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI= golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ= golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=

208
iface/bind/bind.go Normal file
View File

@@ -0,0 +1,208 @@
package bind
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"syscall"
"github.com/pion/stun"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
)
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library
type ICEBind struct {
// below fields, initialized on open
ipv4 net.PacketConn
udpMux *UniversalUDPMuxDefault
// below are fields initialized on creation
transportNet transport.Net
mu sync.Mutex
}
// NewICEBind create a new instance of ICEBind with a given transportNet function.
// The transportNet can be nil.
func NewICEBind(transportNet transport.Net) *ICEBind {
return &ICEBind{
transportNet: transportNet,
mu: sync.Mutex{},
}
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.udpMux, nil
}
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.ipv4 != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
var err error
b.ipv4, _, err = listenNet("udp4", int(uport))
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.ipv4, Net: b.transportNet})
portAddr, err := netip.ParseAddrPort(b.ipv4.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", b.ipv4.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.ipv4),
},
portAddr.Port(), nil
}
func listenNet(network string, port int) (net.PacketConn, int, error) {
c, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
lAddr := c.LocalAddr()
uAddr, err := net.ResolveUDPAddr(
lAddr.Network(),
lAddr.String(),
)
if err != nil {
return nil, 0, err
}
return c, uAddr.Port, nil
}
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff) {
// WireGuard traffic
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
}
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
}
}
// Close closes the WireGuard socket and UDPMux
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2 error
if b.ipv4 != nil {
c := b.ipv4
b.ipv4 = nil
err1 = c.Close()
}
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
}
if err1 != nil {
return err1
}
return err2
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
func (b *ICEBind) SetMark(mark uint32) error {
return nil
}
// Send bytes to the remote endpoint (peer)
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
nend, ok := endpoint.(conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
addrPort := netip.AddrPort(nend)
_, err := b.ipv4.WriteTo(buff, &net.UDPAddr{
IP: addrPort.Addr().AsSlice(),
Port: int(addrPort.Port()),
Zone: addrPort.Addr().Zone(),
})
return err
}
// ParseEndpoint creates a new endpoint from a string.
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
e, err := netip.ParseAddrPort(s)
return asEndpoint(e), err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]conn.Endpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) conn.Endpoint {
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = conn.Endpoint(conn.StdNetEndpoint(ap))
m[ap] = e
}
return e
}

445
iface/bind/udp_mux.go Normal file
View File

@@ -0,0 +1,445 @@
package bind
import (
"fmt"
"io"
"net"
"strings"
"sync"
"github.com/pion/ice/v2"
"github.com/pion/stun"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/transport/v2"
)
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
mu sync.Mutex
// for UDP connection listen at unspecified address
localAddrsForUnspecified []net.Addr
}
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
// Required for gathering local addresses
// in case a un UDPConn is passed which does not
// bind to a specific local address.
Net transport.Net
InterfaceFilter func(interfaceName string) bool
}
func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{}
ifaces, err := n.Interfaces()
if err != nil {
return ips, err
}
var IPv4Requested, IPv6Requested bool
for _, typ := range networkTypes {
if typ.IsIPv4() {
IPv4Requested = true
}
if typ.IsIPv6() {
IPv6Requested = true
}
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface
}
if interfaceFilter != nil && !interfaceFilter(iface.Name) {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch addr := addr.(type) {
case *net.IPNet:
ip = addr.IP
case *net.IPAddr:
ip = addr.IP
}
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue
}
if ipv4 := ip.To4(); ipv4 == nil {
if !IPv6Requested {
continue
} else if !isSupportedIPv6(ip) {
continue
}
} else if !IPv4Requested {
continue
}
if ipFilter != nil && !ipFilter(ip) {
continue
}
ips = append(ips, ip)
}
}
return ips, nil
}
// The conditions of invalidation written below are defined in
// https://tools.ietf.org/html/rfc8445#section-5.1.1.1
func isSupportedIPv6(ip net.IP) bool {
if len(ip) != net.IPv6len ||
isZeros(ip[0:12]) || // !(IPv4-compatible IPv6)
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast)
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() {
return false
}
return true
}
func isZeros(ip net.IP) bool {
for i := 0; i < len(ip); i++ {
if ip[i] != 0 {
return false
}
}
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
}
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified
}
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != ufrag {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.connsIPv4 {
_ = c.Close()
}
for _, c := range m.connsIPv6 {
_ = c.Close()
}
m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)
close(m.closedChan)
_ = m.params.UDPConn.Close()
})
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.Unlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
type bufferHolder struct {
buf []byte
}
func newBufferHolder(size int) *bufferHolder {
return &bufferHolder{
buf: make([]byte, size),
}
}

View File

@@ -0,0 +1,254 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
*/
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/v2"
)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
}
// GetListenAddresses returns the listen addr of this UDP
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
// Not implemented yet.
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
return nil, fmt.Errorf("not implemented yet")
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
// All other STUN packets will be forwarded to the UDPMux
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {
log.Debugf("%s: %v", fmt.Errorf("failed to get XOR-MAPPED-ADDRESS response"), err)
return nil
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
if !ok {
return fmt.Errorf("no XOR address mapping")
}
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err != nil {
return err
}
m.xorMappedMap[stunAddr.String()] = mappedAddr
mappedAddr.SetAddr(&addr)
return nil
}
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// if we already have a mapping for this STUN server (address already received)
// and if it is not too old we return it without making a new request to STUN server
if ok {
if mappedAddr.expired() {
mappedAddr.closeWaiters()
delete(m.xorMappedMap, serverAddr.String())
ok = false
} else if mappedAddr.pending() {
ok = false
}
}
m.mu.Unlock()
if ok {
return mappedAddr.addr, nil
}
// otherwise, make a STUN request to discover the address
// or wait for already sent request to complete
waitAddrReceived, err := m.sendStun(serverAddr)
if err != nil {
return nil, fmt.Errorf("%s: %s", "failed to send STUN packet", err)
}
// block until response was handled by the connWorker routine and XORMappedAddress was updated
select {
case <-waitAddrReceived:
// when channel closed, addr was obtained
m.mu.Lock()
mappedAddr := *m.xorMappedMap[serverAddr.String()]
m.mu.Unlock()
if mappedAddr.addr == nil {
return nil, fmt.Errorf("no XOR address mapping")
}
return mappedAddr.addr, nil
case <-time.After(deadline):
return nil, fmt.Errorf("timeout while waiting for XORMappedAddr")
}
}
// sendStun sends a STUN request via UDP conn.
//
// The returned channel is closed when the STUN response has been received.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
// if record present in the map, we already sent a STUN request,
// just wait when waitAddrReceived will be closed
addrMap, ok := m.xorMappedMap[serverAddr.String()]
if !ok {
addrMap = &xorMapped{
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
waitAddrReceived: make(chan struct{}),
}
m.xorMappedMap[serverAddr.String()] = addrMap
}
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
if err != nil {
return nil, err
}
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
return nil, err
}
return addrMap.waitAddrReceived, nil
}
type xorMapped struct {
addr *stun.XORMappedAddress
waitAddrReceived chan struct{}
expiresAt time.Time
}
func (a *xorMapped) closeWaiters() {
select {
case <-a.waitAddrReceived:
// notify was close, ok, that means we received duplicate response
// just exit
break
default:
// notify tha twe have a new addr
close(a.waitAddrReceived)
}
}
func (a *xorMapped) pending() bool {
return a.addr == nil
}
func (a *xorMapped) expired() bool {
return a.expiresAt.Before(time.Now())
}
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
a.addr = addr
a.closeWaiters()
}

View File

@@ -0,0 +1,233 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v2/packetio"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buf *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buf: packetio.NewBuffer(),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buf.Read(buf.buf)
if err != nil {
return 0, nil, err
}
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
// read data and then address
offset := 2
copy(b, buf.buf[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
offset += 2
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
return 0, nil, err
}
return dataLen, rAddr, nil
}
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, rAddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buf.Close()
close(c.closedChan)
})
return err
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buf) < len(data)+maxAddrSize {
return io.ErrShortBuffer
}
// data len
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
offset := 2
// data
copy(buf.buf[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
if err != nil {
return err
}
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
return err
}
return nil
}
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
ipData, err := addr.IP.MarshalText()
if err != nil {
return 0, err
}
total := 2 + len(ipData) + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
binary.LittleEndian.PutUint16(buf, uint16(len(ipData)))
offset := 2
n := copy(buf[offset:], ipData)
offset += n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2
copy(buf[offset:], addr.Zone)
return total, nil
}
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// basic bounds checking
if ipLen+offset > len(buf) {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)
return &addr, nil
}

View File

@@ -5,6 +5,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -19,6 +21,17 @@ type WGIface struct {
tun *tunDevice tun *tunDevice
configurer wGConfigurer configurer wGConfigurer
mu sync.Mutex mu sync.Mutex
userspaceBind bool
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind
}
// GetBind returns a userspace implementation of WireGuard Bind interface
func (w *WGIface) GetBind() *bind.ICEBind {
return w.tun.iceBind
} }
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
@@ -26,7 +39,7 @@ type WGIface struct {
func (w *WGIface) Create() error { func (w *WGIface) Create() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("create Wireguard interface %s", w.tun.DeviceName()) log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create() return w.tun.Create()
} }

View File

@@ -1,22 +1,28 @@
package iface package iface
import "sync" import (
"sync"
// NewWGIFace Creates a new Wireguard interface instance "github.com/pion/transport/v2"
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) { )
wgIface := &WGIface{
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return wgIface, err return wgIFace, err
} }
tun := newTunDevice(wgAddress, mtu, tunAdapter) tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet)
wgIface.tun = tun wgIFace.tun = tun
wgIface.configurer = newWGConfigurer(tun) wgIFace.configurer = newWGConfigurer(tun)
return wgIface, nil wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
} }

View File

@@ -2,21 +2,26 @@
package iface package iface
import "sync" import (
"sync"
// NewWGIFace Creates a new Wireguard interface instance "github.com/pion/transport/v2"
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) { )
wgIface := &WGIface{
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return wgIface, err return wgIFace, err
} }
wgIface.tun = newTunDevice(ifaceName, wgAddress, mtu) wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIface.configurer = newWGConfigurer(ifaceName) wgIFace.configurer = newWGConfigurer(iFaceName)
return wgIface, nil wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
} }

View File

@@ -2,13 +2,15 @@ package iface
import ( import (
"fmt" "fmt"
"net"
"testing"
"time"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"testing"
"time"
) )
// keep darwin compability // keep darwin compability
@@ -32,7 +34,12 @@ func init() {
func TestWGIface_UpdateAddr(t *testing.T) { func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8" addr := "100.64.0.1/8"
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -92,7 +99,11 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) { func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32" wgIP := "10.99.99.1/32"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -121,7 +132,11 @@ func Test_CreateInterface(t *testing.T) {
func Test_Close(t *testing.T) { func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32" wgIP := "10.99.99.2/32"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -149,7 +164,11 @@ func Test_Close(t *testing.T) {
func Test_ConfigureInterface(t *testing.T) { func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30" wgIP := "10.99.99.5/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -196,7 +215,11 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) { func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30" wgIP := "10.99.99.9/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -255,7 +278,11 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30" wgIP := "10.99.99.13/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil) newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -304,8 +331,11 @@ func Test_ConnectPeers(t *testing.T) {
peer2Key, _ := wgtypes.GeneratePrivateKey() peer2Key, _ := wgtypes.GeneratePrivateKey()
keepAlive := 1 * time.Second keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet()
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil) if err != nil {
t.Fatal(err)
}
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -322,7 +352,11 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil) newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -3,7 +3,7 @@
package iface package iface
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireguardModuleIsLoaded() bool { func WireGuardModuleIsLoaded() bool {
return false return false
} }

View File

@@ -7,9 +7,6 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"io" "io"
"io/fs" "io/fs"
"math" "math"
@@ -17,6 +14,10 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
) )
// Holds logic to check existence of kernel modules used by wireguard interfaces // Holds logic to check existence of kernel modules used by wireguard interfaces
@@ -33,6 +34,7 @@ const (
loading loading
live live
inuse inuse
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
) )
type module struct { type module struct {
@@ -81,9 +83,15 @@ func tunModuleIsLoaded() bool {
return tunLoaded return tunLoaded
} }
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireguardModuleIsLoaded() bool { func WireGuardModuleIsLoaded() bool {
if canCreateFakeWireguardInterface() {
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
return true return true
} }
@@ -96,7 +104,7 @@ func WireguardModuleIsLoaded() bool {
return loaded return loaded
} }
func canCreateFakeWireguardInterface() bool { func canCreateFakeWireGuardInterface() bool {
link := newWGLink("mustnotexist") link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid // We willingly try to create a device with an invalid

View File

@@ -3,13 +3,14 @@ package iface
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
) )
func TestGetModuleDependencies(t *testing.T) { func TestGetModuleDependencies(t *testing.T) {

View File

@@ -2,6 +2,6 @@ package iface
// TunAdapter is an interface for create tun device from externel service // TunAdapter is an interface for create tun device from externel service
type TunAdapter interface { type TunAdapter interface {
ConfigureInterface(address string, mtu int) (int, error) ConfigureInterface(address string, mtu int, routes string) (int, error)
UpdateAddr(address string) error UpdateAddr(address string) error
} }

View File

@@ -1,38 +1,43 @@
package iface package iface
import ( import (
"net" "strings"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
) )
type tunDevice struct { type tunDevice struct {
address WGAddress address WGAddress
mtu int mtu int
routes []string
tunAdapter TunAdapter tunAdapter TunAdapter
fd int fd int
name string name string
device *device.Device device *device.Device
uapi net.Listener iceBind *bind.ICEBind
} }
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter) *tunDevice { func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
return &tunDevice{ return &tunDevice{
address: address, address: address,
mtu: mtu, mtu: mtu,
routes: routes,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet),
} }
} }
func (t *tunDevice) Create() error { func (t *tunDevice) Create() error {
var err error var err error
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu) routesString := t.routesToString()
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return err return err
@@ -46,35 +51,11 @@ func (t *tunDevice) Create() error {
t.name = name t.name = name
log.Debugf("attaching to interface %v", name) log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device.DisableSomeRoamingForBrokenMobileSemantics() t.device.DisableSomeRoamingForBrokenMobileSemantics()
log.Debugf("create uapi")
tunSock, err := ipc.UAPIOpen(name)
if err != nil {
return err
}
t.uapi, err = ipc.UAPIListen(name, tunSock)
if err != nil {
tunSock.Close()
unix.Close(t.fd)
return err
}
go func() {
for {
uapiConn, err := t.uapi.Accept()
if err != nil {
return
}
go t.device.IpcHandle(uapiConn)
}
}()
err = t.device.Up() err = t.device.Up()
if err != nil { if err != nil {
tunSock.Close()
t.device.Close() t.device.Close()
return err return err
} }
@@ -100,13 +81,13 @@ func (t *tunDevice) UpdateAddr(addr WGAddress) error {
} }
func (t *tunDevice) Close() (err error) { func (t *tunDevice) Close() (err error) {
if t.uapi != nil {
err = t.uapi.Close()
}
if t.device != nil { if t.device != nil {
t.device.Close() t.device.Close()
} }
return return
} }
func (t *tunDevice) routesToString() string {
return strings.Join(t.routes, ";")
}

View File

@@ -11,7 +11,7 @@ import (
) )
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
if WireguardModuleIsLoaded() { if WireGuardModuleIsLoaded() {
log.Info("using kernel WireGuard") log.Info("using kernel WireGuard")
return c.createWithKernel() return c.createWithKernel()
} }
@@ -30,7 +30,7 @@ func (c *tunDevice) Create() error {
} }
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module. // createWithKernel Creates a new WireGuard interface using kernel WireGuard module.
// Works for Linux and offers much better network performance // Works for Linux and offers much better network performance
func (c *tunDevice) createWithKernel() error { func (c *tunDevice) createWithKernel() error {

View File

@@ -6,10 +6,13 @@ import (
"net" "net"
"os" "os"
log "github.com/sirupsen/logrus" "github.com/pion/transport/v2"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@@ -18,13 +21,18 @@ type tunDevice struct {
address WGAddress address WGAddress
mtu int mtu int
netInterface NetInterface netInterface NetInterface
iceBind *bind.ICEBind
uapi net.Listener
close chan struct{}
} }
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice { func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{ return &tunDevice{
name: name, name: name,
address: address, address: address,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
} }
} }
@@ -42,23 +50,38 @@ func (c *tunDevice) DeviceName() string {
} }
func (c *tunDevice) Close() error { func (c *tunDevice) Close() error {
if c.netInterface == nil {
return nil select {
case c.close <- struct{}{}:
default:
} }
err := c.netInterface.Close()
if err != nil { var err1, err2, err3 error
return err if c.netInterface != nil {
err1 = c.netInterface.Close()
}
if c.uapi != nil {
err2 = c.uapi.Close()
} }
sockPath := "/var/run/wireguard/" + c.name + ".sock" sockPath := "/var/run/wireguard/" + c.name + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil { if _, statErr := os.Stat(sockPath); statErr == nil {
statErr = os.Remove(sockPath) statErr = os.Remove(sockPath)
if statErr != nil { if statErr != nil {
return statErr err3 = statErr
} }
} }
return nil if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
} }
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation // createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
@@ -69,26 +92,36 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
} }
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDevice.Up() err = tunDev.Up()
if err != nil { if err != nil {
return tunIface, err _ = tunIface.Close()
return nil, err
} }
// todo: after this line in case of error close the tunSock c.uapi, err = c.getUAPI(c.name)
uapi, err := c.getUAPI(c.name)
if err != nil { if err != nil {
return tunIface, err _ = tunIface.Close()
return nil, err
} }
go func() { go func() {
for { for {
uapiConn, uapiErr := uapi.Accept() select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil { if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr) log.Traceln("uapi Accept failed with error: ", uapiErr)
continue continue
} }
go tunDevice.IpcHandle(uapiConn) go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
} }
}() }()

View File

@@ -4,24 +4,39 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/iface/bind"
) )
type tunDevice struct { type tunDevice struct {
name string name string
address WGAddress address WGAddress
netInterface NetInterface netInterface NetInterface
iceBind *bind.ICEBind
mtu int
uapi net.Listener
close chan struct{}
} }
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice { func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{name: name, address: address} return &tunDevice{
name: name,
address: address,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
close: make(chan struct{}),
}
} }
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
var err error var err error
c.netInterface, err = c.createAdapter() c.netInterface, err = c.createWithUserspace()
if err != nil { if err != nil {
return err return err
} }
@@ -29,6 +44,51 @@ func (c *tunDevice) Create() error {
return c.assignAddr() return c.assignAddr()
} }
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (c *tunDevice) createWithUserspace() (NetInterface, error) {
tunIface, err := tun.CreateTUN(c.name, c.mtu)
if err != nil {
return nil, err
}
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDev.Up()
if err != nil {
_ = tunIface.Close()
return nil, err
}
c.uapi, err = c.getUAPI(c.name)
if err != nil {
_ = tunIface.Close()
return nil, err
}
go func() {
for {
select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
}
}()
log.Debugln("UAPI listener started")
return tunIface, nil
}
func (c *tunDevice) UpdateAddr(address WGAddress) error { func (c *tunDevice) UpdateAddr(address WGAddress) error {
c.address = address c.address = address
return c.assignAddr() return c.assignAddr()
@@ -43,19 +103,33 @@ func (c *tunDevice) DeviceName() string {
} }
func (c *tunDevice) Close() error { func (c *tunDevice) Close() error {
if c.netInterface == nil { select {
return nil case c.close <- struct{}{}:
default:
} }
return c.netInterface.Close() var err1, err2 error
if c.netInterface != nil {
err1 = c.netInterface.Close()
}
if c.uapi != nil {
err2 = c.uapi.Close()
}
if err1 != nil {
return err1
}
return err2
} }
func (c *tunDevice) getInterfaceGUIDString() (string, error) { func (c *tunDevice) getInterfaceGUIDString() (string, error) {
if c.netInterface == nil { if c.netInterface == nil {
return "", fmt.Errorf("interface has not been initialized yet") return "", fmt.Errorf("interface has not been initialized yet")
} }
windowsDevice := c.netInterface.(*driver.Adapter) windowsDevice := c.netInterface.(*tun.NativeTun)
luid := windowsDevice.LUID() luid := winipcfg.LUID(windowsDevice.LUID())
guid, err := luid.GUID() guid, err := luid.GUID()
if err != nil { if err != nil {
return "", err return "", err
@@ -63,31 +137,15 @@ func (c *tunDevice) getInterfaceGUIDString() (string, error) {
return guid.String(), nil return guid.String(), nil
} }
func (c *tunDevice) createAdapter() (NetInterface, error) {
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
adapter, err := driver.CreateAdapter(c.name, "WireGuard", &WintunStaticRequestedGUID)
if err != nil {
err = fmt.Errorf("error creating adapter: %w", err)
return nil, err
}
err = adapter.SetAdapterState(driver.AdapterStateUp)
if err != nil {
return adapter, err
}
state, _ := adapter.LUID().GUID()
log.Debugln("device guid: ", state.String())
return adapter, nil
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (c *tunDevice) assignAddr() error { func (c *tunDevice) assignAddr() error {
luid := c.netInterface.(*driver.Adapter).LUID() tunDev := c.netInterface.(*tun.NativeTun)
luid := winipcfg.LUID(tunDev.LUID())
log.Debugf("adding address %s to interface: %s", c.address.IP, c.name) log.Debugf("adding address %s to interface: %s", c.address.IP, c.name)
err := luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}}) return luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
if err != nil { }
return err
} // getUAPI returns a Listener
func (c *tunDevice) getUAPI(iface string) (net.Listener, error) {
return nil return ipc.UAPIListen(iface)
} }

View File

@@ -172,6 +172,49 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil return nil
} }
// GetRoutes return with the routes
func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
}
defer func() {
_ = stream.CloseSend()
}()
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return nil, err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return nil, err
}
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return nil, err
}
if decryptedResp.GetNetworkMap() == nil {
return nil, fmt.Errorf("invalid msg, required network map")
}
return decryptedResp.GetNetworkMap().GetRoutes(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{} req := &proto.SyncRequest{}

View File

@@ -49,15 +49,17 @@ type AccountManager interface {
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error) autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, userID string, key *UserInfo) (*UserInfo, error) CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, executingUserID string, targetUserID string) error
ListSetupKeys(accountID, userID string) ([]*SetupKey, error) ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, userID string, update *User) (*UserInfo, error) SaveUser(accountID, userID string, update *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountByUserID(userID string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(userID string) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error) GetPeerByKey(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error) GetPeers(accountID, userID string) ([]*Peer, error)
@@ -177,6 +179,7 @@ type UserInfo struct {
Role string `json:"role"` Role string `json:"role"`
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
Status string `json:"-"` Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
} }
// getRoutesToSync returns the enabled routes for the peer ID and the routes // getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -283,7 +286,7 @@ func (a *Account) GetGroup(groupID string) *Group {
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise // GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
aclPeers, _ := a.getPeersByPolicy(peerID) aclPeers := a.getPeersByACL(peerID)
// exclude expired peers // exclude expired peers
var peersToConnect []*Peer var peersToConnect []*Peer
var expiredPeers []*Peer var expiredPeers []*Peer
@@ -1228,10 +1231,12 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if !user.IsServiceUser {
err = am.redeemInvite(account, claims.UserId) err = am.redeemInvite(account, claims.UserId)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
}
return account, user, nil return account, user, nil
} }

View File

@@ -87,6 +87,10 @@ const (
PersonalAccessTokenCreated PersonalAccessTokenCreated
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token // PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted PersonalAccessTokenDeleted
// ServiceUserCreated indicates that a user created a service user
ServiceUserCreated
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
) )
const ( const (
@@ -176,6 +180,10 @@ const (
PersonalAccessTokenCreatedMessage string = "Personal access token created" PersonalAccessTokenCreatedMessage string = "Personal access token created"
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity // PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
PersonalAccessTokenDeletedMessage string = "Personal access token deleted" PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
// ServiceUserCreatedMessage is a human-readable text message of the ServiceUserCreated activity
ServiceUserCreatedMessage string = "Service user created"
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
ServiceUserDeletedMessage string = "Service user deleted"
) )
// Activity that triggered an Event // Activity that triggered an Event
@@ -270,6 +278,10 @@ func (a Activity) Message() string {
return PersonalAccessTokenCreatedMessage return PersonalAccessTokenCreatedMessage
case PersonalAccessTokenDeleted: case PersonalAccessTokenDeleted:
return PersonalAccessTokenDeletedMessage return PersonalAccessTokenDeletedMessage
case ServiceUserCreated:
return ServiceUserCreatedMessage
case ServiceUserDeleted:
return ServiceUserDeletedMessage
default: default:
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }
@@ -364,6 +376,10 @@ func (a Activity) StringCode() string {
return "personal.access.token.create" return "personal.access.token.create"
case PersonalAccessTokenDeleted: case PersonalAccessTokenDeleted:
return "personal.access.token.delete" return "personal.access.token.delete"
case ServiceUserCreated:
return "service.user.create"
case ServiceUserDeleted:
return "service.user.delete"
default: default:
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }

View File

@@ -286,9 +286,7 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
} }
if accountCopy.Rules == nil {
accountCopy.Rules = make(map[string]*Rule) accountCopy.Rules = make(map[string]*Rule)
}
for _, policy := range accountCopy.Policies { for _, policy := range accountCopy.Policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
accountCopy.Rules[rule.ID] = rule.ToRule() accountCopy.Rules[rule.ID] = rule.ToRule()

View File

@@ -77,6 +77,10 @@ components:
description: Is true if authenticated user is the same as this user description: Is true if authenticated user is the same as this user
type: boolean type: boolean
readOnly: true readOnly: true
is_service_user:
description: Is true if this user is a service user
type: boolean
readOnly: true
required: required:
- id - id
- email - email
@@ -115,10 +119,13 @@ components:
type: array type: array
items: items:
type: string type: string
is_service_user:
description: Is true if this user is a service user
type: boolean
required: required:
- role - role
- auto_groups - auto_groups
- email - is_service_user
PeerMinimum: PeerMinimum:
type: object type: object
properties: properties:
@@ -825,6 +832,12 @@ paths:
tags: [ Users ] tags: [ Users ]
security: security:
- BearerAuth: [ ] - BearerAuth: [ ]
parameters:
- in: query
name: service_user
schema:
type: boolean
description: Filters users and returns either normal users or service users
responses: responses:
'200': '200':
description: A JSON array of Users description: A JSON array of Users
@@ -903,6 +916,30 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
delete:
summary: Delete a User
tags: [ Users ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The User ID
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{userId}/tokens: /api/users/{userId}/tokens:
get: get:
summary: Returns a list of all tokens for a user summary: Returns a list of all tokens for a user

View File

@@ -677,6 +677,9 @@ type User struct {
// IsCurrent Is true if authenticated user is the same as this user // IsCurrent Is true if authenticated user is the same as this user
IsCurrent *bool `json:"is_current,omitempty"` IsCurrent *bool `json:"is_current,omitempty"`
// IsServiceUser Is true if this user is a service user
IsServiceUser *bool `json:"is_service_user,omitempty"`
// Name User's name from idp provider // Name User's name from idp provider
Name string `json:"name"` Name string `json:"name"`
@@ -696,7 +699,10 @@ type UserCreateRequest struct {
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Email User's Email to send invite to // Email User's Email to send invite to
Email string `json:"email"` Email *string `json:"email,omitempty"`
// IsServiceUser Is true if this user is a service user
IsServiceUser bool `json:"is_service_user"`
// Name User's full name // Name User's full name
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
@@ -787,6 +793,12 @@ type PutApiRulesIdJSONBody struct {
Sources *[]string `json:"sources,omitempty"` Sources *[]string `json:"sources,omitempty"`
} }
// GetApiUsersParams defines parameters for GetApiUsers.
type GetApiUsersParams struct {
// ServiceUser Filters users and returns either normal users or service users
ServiceUser *bool `form:"service_user,omitempty" json:"service_user,omitempty"`
}
// PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType. // PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType.
type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody

View File

@@ -111,6 +111,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{id}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
} }

View File

@@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) type IsUserAdminFunc func(userID string) (bool, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct { type AccessControl struct {
@@ -37,7 +37,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := a.claimsExtract.FromRequestContext(r) claims := a.claimsExtract.FromRequestContext(r)
ok, err := a.isUserAdmin(claims) ok, err := a.isUserAdmin(claims.UserId)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return

View File

@@ -3,8 +3,10 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
@@ -77,6 +79,36 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
} }
// DeleteUser is a DELETE request to delete a user (only works for service users right now)
func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["id"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
}
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). // CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@@ -103,11 +135,17 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
email := ""
if req.Email != nil {
email = *req.Email
}
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{ newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
Email: req.Email, Email: email,
Name: *req.Name, Name: *req.Name,
Role: req.Role, Role: req.Role,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
IsServiceUser: req.IsServiceUser,
}) })
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@@ -137,9 +175,27 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
return return
} }
serviceUser := r.URL.Query().Get("service_user")
log.Debugf("UserCount: %v", len(data))
users := make([]*api.User, 0) users := make([]*api.User, 0)
for _, r := range data { for _, r := range data {
if serviceUser == "" {
users = append(users, toUserResponse(r, claims.UserId)) users = append(users, toUserResponse(r, claims.UserId))
continue
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.Debugf("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return
}
log.Debugf("User %v is service user: %v", r.Name, r.IsServiceUser)
if includeServiceUser == r.IsServiceUser {
log.Debugf("Found service user: %v", r.Name)
users = append(users, toUserResponse(r, claims.UserId))
}
} }
util.WriteJSONObject(w, users) util.WriteJSONObject(w, users)
@@ -170,5 +226,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: userStatus, Status: userStatus,
IsCurrent: &isCurrent, IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser,
} }
} }

View File

@@ -1,52 +1,91 @@
package http package http
import ( import (
"bytes"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/magiconair/properties/assert" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
) )
func initUsers(user ...*server.User) *UsersHandler { const (
serviceUserID = "serviceUserID"
regularUserID = "regularUserID"
)
var usersTestAccount = &server.Account{
Id: existingAccountID,
Domain: domain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
Role: "admin",
IsServiceUser: false,
},
regularUserID: {
Id: regularUserID,
Role: "user",
IsServiceUser: false,
},
serviceUserID: {
Id: serviceUserID,
Role: "user",
IsServiceUser: true,
},
},
}
func initUsersTestData() *UsersHandler {
return &UsersHandler{ return &UsersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
users := make(map[string]*server.User, 0) return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
for _, u := range user {
users[u.Id] = u
}
return &server.Account{
Id: "12345",
Domain: "netbird.io",
Users: users,
}, users[claims.UserId], nil
}, },
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)
for _, v := range user { for _, v := range usersTestAccount.Users {
users = append(users, &server.UserInfo{ users = append(users, &server.UserInfo{
ID: v.Id, ID: v.Id,
Role: string(v.Role), Role: string(v.Role),
Name: "", Name: "",
Email: "", Email: "",
IsServiceUser: v.IsServiceUser,
}) })
} }
return users, nil return users, nil
}, },
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
return key, nil
},
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
if !usersTestAccount.Users[targetUserID].IsServiceUser {
return status.Errorf(status.PermissionDenied, "user with ID %s is not a service user and can not be deleted", targetUserID)
}
return nil
},
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "1", UserId: existingUserID,
Domain: "hotmail.com", Domain: domain,
AccountId: "test_id", AccountId: existingAccountID,
} }
}), }),
), ),
@@ -54,8 +93,84 @@ func initUsers(user ...*server.User) *UsersHandler {
} }
func TestGetUsers(t *testing.T) { func TestGetUsers(t *testing.T) {
users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}} tt := []struct {
userHandler := initUsers(users...) name string
expectedStatus int
requestType string
requestPath string
expectedUserIDs []string
}{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
userHandler.GetAllUsers(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
respBody := []*server.UserInfo{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, len(respBody), len(tc.expectedUserIDs))
for _, v := range respBody {
assert.Contains(t, tc.expectedUserIDs, v.ID)
assert.Equal(t, v.ID, usersTestAccount.Users[v.ID].Id)
assert.Equal(t, v.Role, string(usersTestAccount.Users[v.ID].Role))
assert.Equal(t, v.IsServiceUser, usersTestAccount.Users[v.ID].IsServiceUser)
}
})
}
}
func TestCreateUser(t *testing.T) {
name := "name"
email := "email"
serviceUserToAdd := api.UserCreateRequest{
AutoGroups: []string{},
Email: nil,
IsServiceUser: true,
Name: &name,
Role: "admin",
}
serviceUserString, err := json.Marshal(serviceUserToAdd)
if err != nil {
t.Fatal(err)
}
regularUserToAdd := api.UserCreateRequest{
AutoGroups: []string{},
Email: &email,
IsServiceUser: true,
Name: &name,
Role: "admin",
}
regularUserString, err := json.Marshal(regularUserToAdd)
if err != nil {
t.Fatal(err)
}
tt := []struct { tt := []struct {
name string name string
@@ -65,40 +180,79 @@ func TestGetUsers(t *testing.T) {
requestBody io.Reader requestBody io.Reader
expectedResult []*server.User expectedResult []*server.User
}{ }{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users}, {name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
// right now creation is blocked in AC middleware, will be refactored in the future
{name: "CreateRegularUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(regularUserString)},
} }
userHandler := initUsersTestData()
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.GetAllUsers(rr, req) userHandler.CreateUser(rr, req)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus { if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK) status, tc.expectedStatus)
} }
})
content, err := io.ReadAll(res.Body) }
if err != nil { }
t.Fatal(err)
} func TestDeleteUser(t *testing.T) {
tt := []struct {
respBody := []*server.UserInfo{} name string
err = json.Unmarshal(content, &respBody) expectedStatus int
if err != nil { expectedBody bool
t.Fatalf("Sent content is not in correct json format; %v", err) requestType string
} requestPath string
requestVars map[string]string
if tc.expectedResult != nil { requestBody io.Reader
for i, resp := range respBody { }{
assert.Equal(t, resp.ID, tc.expectedResult[i].Id) {
assert.Equal(t, string(resp.Role), string(tc.expectedResult[i].Role)) name: "Delete Regular User",
} requestType: http.MethodDelete,
requestPath: "/api/users/" + regularUserID,
requestVars: map[string]string{"id": regularUserID},
expectedStatus: http.StatusForbidden,
},
{
name: "Delete Service User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + serviceUserID,
requestVars: map[string]string{"id": serviceUserID},
expectedStatus: http.StatusOK,
},
{
name: "Delete Not Existing User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + notFoundUserID,
requestVars: map[string]string{"id": notFoundUserID},
expectedStatus: http.StatusNotFound,
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
rr := httptest.NewRecorder()
userHandler.DeleteUser(rr, req)
res := rr.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
} }
}) })
} }

View File

@@ -15,12 +15,12 @@ import (
type MockAccountManager struct { type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error) GetAccountByUserIDFunc func(userID string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdminFunc func(userID string) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error) AccountExistsFunc func(accountId string) (*bool, error)
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error) GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
@@ -61,6 +61,7 @@ type MockAccountManager struct {
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
@@ -112,12 +113,12 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
) )
} }
// GetAccountByUser mock implementation of GetAccountByUser from server.AccountManager interface // GetAccountByUserID mock implementation of GetAccountByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) { func (am *MockAccountManager) GetAccountByUserID(userID string) (*server.Account, error) {
if am.GetAccountByUserFunc != nil { if am.GetAccountByUserIDFunc != nil {
return am.GetAccountByUserFunc(userId) return am.GetAccountByUserIDFunc(userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserID is not implemented")
} }
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface // CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
@@ -394,9 +395,9 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
} }
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface // IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *MockAccountManager) IsUserAdmin(userID string) (bool, error) {
if am.IsUserAdminFunc != nil { if am.IsUserAdminFunc != nil {
return am.IsUserAdminFunc(claims) return am.IsUserAdminFunc(userID)
} }
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented") return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
} }
@@ -500,6 +501,14 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
} }
// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
}
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil { if am.GetNameServerGroupFunc != nil {

View File

@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
// TODO: use firewall rules // TODO: use firewall rules
aclPeers, _ := account.getPeersByPolicy(peer.ID) aclPeers := account.getPeersByACL(peer.ID)
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -816,7 +816,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*Pee
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.getPeersByPolicy(p.ID) aclPeers := account.getPeersByACL(p.ID)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID { if aclPeer.ID == peerID {
return peer, nil return peer, nil
@@ -833,6 +833,98 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) *Peer {
return peer return peer
} }
// GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{})
for s, group := range a.Groups {
for _, peer := range group.Peers {
if peerID == peer {
peerGroups[s] = struct{}{}
break
}
}
}
// Second, find all rules that have discovered source and destination groups
srcRulesMap := make(map[string]*Rule)
dstRulesMap := make(map[string]*Rule)
for _, rule := range a.Rules {
for _, g := range rule.Source {
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
srcRules = append(srcRules, rule)
srcRulesMap[rule.ID] = rule
}
}
for _, g := range rule.Destination {
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
dstRules = append(dstRules, rule)
dstRulesMap[rule.ID] = rule
}
}
}
return srcRules, dstRules
}
// getPeersByACL returns all peers that given peer has access to.
func (a *Account) getPeersByACL(peerID string) []*Peer {
var peers []*Peer
srcRules, dstRules := a.GetPeerRules(peerID)
groups := map[string]*Group{}
for _, r := range srcRules {
if r.Disabled {
continue
}
if r.Flow == TrafficFlowBidirect {
for _, gid := range r.Destination {
if group, ok := a.Groups[gid]; ok {
groups[gid] = group
}
}
}
}
for _, r := range dstRules {
if r.Disabled {
continue
}
if r.Flow == TrafficFlowBidirect {
for _, gid := range r.Source {
if group, ok := a.Groups[gid]; ok {
groups[gid] = group
}
}
}
}
peersSet := make(map[string]struct{})
for _, g := range groups {
for _, pid := range g.Peers {
peer, ok := a.Peers[pid]
if !ok {
log.Warnf(
"peer %s found in group %s but doesn't belong to account %s",
pid,
g.ID,
a.Id,
)
continue
}
// exclude original peer
if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok {
peersSet[peer.ID] = struct{}{}
peers = append(peers, peer.Copy())
}
}
}
return peers
}
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {

View File

@@ -136,6 +136,8 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
} }
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
// TODO: disable until we start use policy again
t.Skip()
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -4,11 +4,11 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@@ -44,6 +44,9 @@ type UserRole string
type User struct { type User struct {
Id string Id string
Role UserRole Role UserRole
IsServiceUser bool
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string AutoGroups []string
PATs map[string]*PersonalAccessToken PATs map[string]*PersonalAccessToken
@@ -65,10 +68,11 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
return &UserInfo{ return &UserInfo{
ID: u.Id, ID: u.Id,
Email: "", Email: "",
Name: "", Name: u.ServiceUserName,
Role: string(u.Role), Role: string(u.Role),
AutoGroups: u.AutoGroups, AutoGroups: u.AutoGroups,
Status: string(UserStatusActive), Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
}, nil }, nil
} }
if userData.ID != u.Id { if userData.ID != u.Id {
@@ -87,6 +91,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
Role: string(u.Role), Role: string(u.Role),
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: string(userStatus), Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
}, nil }, nil
} }
@@ -104,31 +109,85 @@ func (u *User) Copy() *User {
Id: u.Id, Id: u.Id,
Role: u.Role, Role: u.Role,
AutoGroups: autoGroups, AutoGroups: autoGroups,
IsServiceUser: u.IsServiceUser,
ServiceUserName: u.ServiceUserName,
PATs: pats, PATs: pats,
} }
} }
// NewUser creates a new user // NewUser creates a new user
func NewUser(id string, role UserRole) *User { func NewUser(id string, role UserRole, isServiceUser bool, serviceUserName string, autoGroups []string) *User {
return &User{ return &User{
Id: id, Id: id,
Role: role, Role: role,
AutoGroups: []string{}, IsServiceUser: isServiceUser,
ServiceUserName: serviceUserName,
AutoGroups: autoGroups,
} }
} }
// NewRegularUser creates a new user with role UserRoleAdmin // NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User { func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser) return NewUser(id, UserRoleUser, false, "", []string{})
} }
// NewAdminUser creates a new user with role UserRoleAdmin // NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User { func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin) return NewUser(id, UserRoleAdmin, false, "", []string{})
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
}
executingUser := account.Users[executingUserID]
if executingUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin {
return nil, status.Errorf(status.PermissionDenied, "only admins can create service users")
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, serviceUserName, autoGroups)
log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
meta := map[string]any{"name": newUser.ServiceUserName}
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
return &UserInfo{
ID: newUser.Id,
Email: "",
Name: newUser.ServiceUserName,
Role: string(newUser.Role),
AutoGroups: newUser.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: true,
}, nil
} }
// CreateUser creates a new user under the given account. Effectively this is a user invite. // CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) { func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *UserInfo) (*UserInfo, error) {
if user.IsServiceUser {
return am.createServiceUser(accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.AutoGroups)
}
return am.inviteNewUser(accountID, userID, user)
}
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -193,8 +252,48 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us
} }
// DeleteUser deletes a user from the given account.
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
targetUser := account.Users[targetUserID]
if targetUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin {
return status.Errorf(status.PermissionDenied, "only admins can delete service users")
}
if !targetUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "regular users can not be deleted")
}
meta := map[string]any{"name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
delete(account.Users, targetUserID)
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
return nil
}
// CreatePAT creates a new PAT for the given user // CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -206,21 +305,26 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
} }
if executingUserID != targetUserId {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser := account.Users[targetUserId] targetUser := account.Users[targetUserID]
if targetUser == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "targetUser not found") return nil, status.Errorf(status.NotFound, "targetUser not found")
} }
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id) executingUser := account.Users[executingUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
if err != nil { if err != nil {
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
} }
@@ -232,8 +336,8 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
return nil, status.Errorf(status.Internal, "failed to save account: %v", err) return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
} }
meta := map[string]any{"name": pat.Name} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserId, accountID, activity.PersonalAccessTokenCreated, meta) am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
return pat, nil return pat, nil
} }
@@ -243,21 +347,26 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
if executingUserID != targetUserID {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(status.NotFound, "account not found: %s", err) return status.Errorf(status.NotFound, "account not found: %s", err)
} }
user := account.Users[targetUserID] targetUser := account.Users[targetUserID]
if user == nil { if targetUser == nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
pat := user.PATs[tokenID] executingUser := account.Users[executingUserID]
if targetUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
}
pat := targetUser.PATs[tokenID]
if pat == nil { if pat == nil {
return status.Errorf(status.NotFound, "PAT not found") return status.Errorf(status.NotFound, "PAT not found")
} }
@@ -271,10 +380,10 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
} }
meta := map[string]any{"name": pat.Name} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
delete(user.PATs, tokenID) delete(targetUser.PATs, tokenID)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
@@ -288,21 +397,26 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
if executingUserID != targetUserID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(status.NotFound, "account not found: %s", err) return nil, status.Errorf(status.NotFound, "account not found: %s", err)
} }
user := account.Users[targetUserID] targetUser := account.Users[targetUserID]
if user == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
pat := user.PATs[tokenID] executingUser := account.Users[executingUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
}
pat := targetUser.PATs[tokenID]
if pat == nil { if pat == nil {
return nil, status.Errorf(status.NotFound, "PAT not found") return nil, status.Errorf(status.NotFound, "PAT not found")
} }
@@ -315,22 +429,27 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
if executingUserID != targetUserID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(status.NotFound, "account not found: %s", err) return nil, status.Errorf(status.NotFound, "account not found: %s", err)
} }
user := account.Users[targetUserID] targetUser := account.Users[targetUserID]
if user == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
executingUser := account.Users[executingUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
var pats []*PersonalAccessToken var pats []*PersonalAccessToken
for _, pat := range user.PATs { for _, pat := range targetUser.PATs {
pats = append(pats, pat) pats = append(pats, pat)
} }
@@ -386,7 +505,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
group := account.GetGroup(g) group := account.GetGroup(g)
if group != nil { if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser, am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID}) map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else { } else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
@@ -397,14 +516,14 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
group := account.GetGroup(g) group := account.GetGroup(g)
if group != nil { if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser, am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID}) map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else { } else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
} }
}() }()
if !isNil(am.idpManager) { if !isNil(am.idpManager) && !newUser.IsServiceUser {
userData, err := am.lookupUserInCache(newUser.Id, account) userData, err := am.lookupUserInCache(newUser.Id, account)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -454,14 +573,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
return account, nil return account, nil
} }
// IsUserAdmin flag for current user authenticated by JWT token // GetAccountByUserID returns an existing account for a given user id
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *DefaultAccountManager) GetAccountByUserID(userID string) (*Account, error) {
account, _, err := am.GetAccountFromToken(claims) return am.Store.GetAccountByUser(userID)
}
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
func (am *DefaultAccountManager) IsUserAdmin(userID string) (bool, error) {
account, err := am.GetAccountByUserID(userID)
if err != nil { if err != nil {
return false, fmt.Errorf("get account: %v", err) return false, fmt.Errorf("get account: %v", err)
} }
user, ok := account.Users[claims.UserId] user, ok := account.Users[userID]
if !ok { if !ok {
return false, status.Errorf(status.NotFound, "user not found") return false, status.Errorf(status.NotFound, "user not found")
} }
@@ -486,8 +610,10 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
users := make(map[string]struct{}, len(account.Users)) users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users { for _, user := range account.Users {
if !user.IsServiceUser {
users[user.Id] = struct{}{} users[user.Id] = struct{}{}
} }
}
queriedUsers, err = am.lookupCache(users, accountID) queriedUsers, err = am.lookupCache(users, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -512,20 +638,44 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
return userInfos, nil return userInfos, nil
} }
for _, queriedUser := range queriedUsers { for _, localUser := range account.Users {
if !user.IsAdmin() && user.Id != queriedUser.ID { if !user.IsAdmin() && user.Id != localUser.Id {
// if user is not an admin then show only current user and do not show other users // if user is not an admin then show only current user and do not show other users
continue continue
} }
if localUser, contains := account.Users[queriedUser.ID]; contains {
info, err := localUser.toUserInfo(queriedUser) var info *UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.toUserInfo(queriedUser)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userInfos = append(userInfos, info) } else {
name := ""
if localUser.IsServiceUser {
name = localUser.ServiceUserName
} }
info = &UserInfo{
ID: localUser.Id,
Email: "",
Name: name,
Role: string(localUser.Role),
AutoGroups: localUser.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: localUser.IsServiceUser,
}
}
userInfos = append(userInfos, info)
} }
return userInfos, nil return userInfos, nil
} }
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {
return user, true
}
}
return nil, false
}

View File

@@ -1,8 +1,12 @@
package server package server
import ( import (
"fmt"
"reflect"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@@ -11,6 +15,9 @@ import (
const ( const (
mockAccountID = "accountID" mockAccountID = "accountID"
mockUserID = "userID" mockUserID = "userID"
mockServiceUserID = "serviceUserID"
mockRole = "user"
mockServiceUserName = "serviceUserName"
mockTargetUserId = "targetUserID" mockTargetUserId = "targetUserID"
mockTokenID1 = "tokenID1" mockTokenID1 = "tokenID1"
mockToken1 = "SoMeHaShEdToKeN1" mockToken1 = "SoMeHaShEdToKeN1"
@@ -41,6 +48,8 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
assert.Equal(t, pat.CreatedBy, mockUserID)
fileStore := am.Store.(*FileStore) fileStore := am.Store.(*FileStore)
tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken] tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken]
@@ -60,7 +69,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
Id: mockTargetUserId,
IsServiceUser: false,
}
err := store.SaveAccount(account) err := store.SaveAccount(account)
if err != nil { if err != nil {
t.Fatalf("Error when saving account: %s", err) t.Fatalf("Error when saving account: %s", err)
@@ -75,6 +87,31 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
assert.Errorf(t, err, "Creating PAT for different user should thorw error") assert.Errorf(t, err, "Creating PAT for different user should thorw error")
} }
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
Id: mockTargetUserId,
IsServiceUser: true,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
pat, err := am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
if err != nil {
t.Fatalf("Error when adding PAT to user: %s", err)
}
assert.Equal(t, pat.CreatedBy, mockUserID)
}
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
@@ -207,3 +244,300 @@ func TestUser_GetAllPATs(t *testing.T) {
assert.Equal(t, 2, len(pats)) assert.Equal(t, 2, len(pats))
} }
func TestUser_Copy(t *testing.T) {
// this is an imaginary case which will never be in DB this way
user := User{
Id: "userId",
Role: "role",
IsServiceUser: true,
ServiceUserName: "servicename",
AutoGroups: []string{"group1", "group2"},
PATs: map[string]*PersonalAccessToken{
"pat1": {
ID: "pat1",
Name: "First PAT",
HashedToken: "SoMeHaShEdToKeN",
ExpirationDate: time.Now().AddDate(0, 0, 7),
CreatedBy: "userId",
CreatedAt: time.Now(),
LastUsed: time.Now(),
},
},
}
err := validateStruct(user)
if err != nil {
t.Fatalf("Test needs update: dummy struct has not all fields set : %s", err)
}
copiedUser := user.Copy()
assert.True(t, cmp.Equal(user, *copiedUser))
}
// based on https://medium.com/@anajankow/fast-check-if-all-struct-fields-are-set-in-golang-bba1917213d2
func validateStruct(s interface{}) (err error) {
structType := reflect.TypeOf(s)
structVal := reflect.ValueOf(s)
fieldNum := structVal.NumField()
for i := 0; i < fieldNum; i++ {
field := structVal.Field(i)
fieldName := structType.Field(i).Name
isSet := field.IsValid() && !field.IsZero()
if !isSet {
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
}
}
return err
}
func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
user, err := am.createServiceUser(mockAccountID, mockUserID, mockRole, mockServiceUserName, []string{"group1", "group2"})
if err != nil {
t.Fatalf("Error when creating service user: %s", err)
}
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID])
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs)
assert.Zero(t, user.Email)
assert.True(t, user.IsServiceUser)
assert.Equal(t, "active", user.Status)
}
func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
user, err := am.CreateUser(mockAccountID, mockUserID, &UserInfo{
Name: mockServiceUserName,
Role: mockRole,
IsServiceUser: true,
AutoGroups: []string{"group1", "group2"},
})
if err != nil {
t.Fatalf("Error when creating user: %s", err)
}
assert.True(t, user.IsServiceUser)
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
assert.Equal(t, mockServiceUserName, user.Name)
assert.Equal(t, mockRole, user.Role)
assert.Equal(t, []string{"group1", "group2"}, user.AutoGroups)
assert.Equal(t, "active", user.Status)
}
func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
_, err = am.CreateUser(mockAccountID, mockUserID, &UserInfo{
Name: mockServiceUserName,
Role: mockRole,
IsServiceUser: false,
AutoGroups: []string{"group1", "group2"},
})
assert.Errorf(t, err, "Not configured IDP will throw error but right path used")
}
func TestUser_DeleteUser_ServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,
IsServiceUser: true,
ServiceUserName: mockServiceUserName,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
err = am.DeleteUser(mockAccountID, mockUserID, mockServiceUserID)
if err != nil {
t.Fatalf("Error when deleting user: %s", err)
}
assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users))
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
}
func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
}
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
ok, err := am.IsUserAdmin(mockUserID)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.True(t, ok)
}
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
Role: "user",
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
ok, err := am.IsUserAdmin(mockUserID)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.False(t, ok)
}
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
users, err := am.GetUsersFromAccount(mockAccountID, mockUserID)
if err != nil {
t.Fatalf("Error when getting users from account: %s", err)
}
assert.Equal(t, 2, len(users))
}
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
users, err := am.GetUsersFromAccount(mockAccountID, mockServiceUserID)
if err != nil {
t.Fatalf("Error when getting users from account: %s", err)
}
assert.Equal(t, 1, len(users))
assert.Equal(t, mockServiceUserID, users[0].ID)
}