Compare commits

..

24 Commits

Author SHA1 Message Date
Bethuel Mmbaga
cf7f6c355f [misc] Remove default zitadel admin user in deployment script (#4482)
* Delete default zitadel-admin user during initialization

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-09-11 21:20:10 +02:00
Zoltan Papp
47e64d72db [client] Fix client status check (#4474)
The client status is not enough to protect the RPC calls from concurrency issues, because it is handled internally in the client in an asynchronous way.
2025-09-11 16:21:09 +02:00
Zoltan Papp
9e81e782e5 [client] Fix/v4 stun routing (#4430)
Deduplicate STUN package sending.
Originally, because every peer shared the same UDP address, the library could not distinguish which STUN message was associated with which candidate. As a result, the Pion library responded from all candidates for every STUN message.
2025-09-11 10:08:54 +02:00
Zoltan Papp
7aef0f67df [client] Implement environment variable handling for Android (#4440)
Some features can only be manipulated via environment variables. With this PR, environment variables can be managed from Android.
2025-09-08 18:42:42 +02:00
Maycon Santos
dba7ef667d [misc] Remove aur support and start service on ostree (#4461)
* Remove aur support and start service on ostree

The aur installation was adding many packages and installing more than just the client. For now is best to remove it and rely on binary install

Some users complained about ostree installation not starting the client, we add two explicit commands to it

* use  ${SUDO}

* fix if closure
2025-09-08 15:03:56 +02:00
Zoltan Papp
69d87343d2 [client] Debug information for connection (#4439)
Improve logging

Print the exact time when the first WireGuard handshake occurs
Print the steps for gathering system information
2025-09-08 14:51:34 +02:00
Bethuel Mmbaga
5113c70943 [management] Extends integration and peers manager (#4450) 2025-09-06 13:13:49 +03:00
Zoltan Papp
ad8fcda67b [client] Move some sys info to static place (#4446)
This PR refactors the system information collection code by moving static system information gathering to a dedicated location and separating platform-specific implementations. The primary goal is to improve code organization and maintainability by centralizing static info collection logic.

Key changes:
- Centralized static info collection into dedicated files with platform-specific implementations
- Moved `StaticInfo` struct definition to the main static_info.go file
- Added async initialization function `UpdateStaticInfoAsync()` across all platforms
2025-09-06 10:49:28 +02:00
Pascal Fischer
d33f88df82 [management] only allow user devices to be expired (#4445) 2025-09-05 18:11:23 +02:00
Zoltan Papp
786ca6fc79 Do not block Offer processing from relay worker (#4435)
- do not miss ICE offers when relay worker busy
- close p2p connection before recreate agent
2025-09-05 11:02:29 +02:00
Diego Romar
dfebdf1444 [internal] Add missing assignment of iFaceDiscover when netstack is disabled (#4444)
The internal updateInterfaces() function expects iFaceDiscover to not
be nil
2025-09-04 23:00:10 +02:00
Bethuel Mmbaga
a8dcff69c2 [management] Add peers manager to integrations (#4405) 2025-09-04 23:07:03 +03:00
Viktor Liu
71e944fa57 [relay] Let relay accept any origin (#4426) 2025-09-01 19:51:06 +02:00
Maycon Santos
d39fcfd62a [management] Add user approval (#4411)
This PR adds user approval functionality to the management system, allowing administrators to manually approve new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin.

Adds UserApprovalRequired setting to control manual user approval requirement
Introduces user approval and rejection endpoints with corresponding business logic
Prevents pending approval users from adding peers or logging in
2025-09-01 18:00:45 +02:00
Zoltan Papp
21368b38d9 [client] Update Pion ICE to the latest version (#4388)
- Update Pion version
- Update protobuf version
2025-09-01 10:42:01 +02:00
Maycon Santos
d817584f52 [misc] fix Windows client and management bench tests (#4424)
Windows tests had too many directories, causing issues to the payload via psexec.

Also migrated all checked benchmarks to send data to grafana.
2025-08-31 17:19:56 +02:00
Pascal Fischer
4d3dc3475d [management] remove duplicated removal of groups on peer delete (#4421) 2025-08-30 12:47:13 +02:00
Pascal Fischer
6fc50a438f [management] remove withContext from store methods (#4422) 2025-08-30 12:46:54 +02:00
Vlad
149559a06b [management] login filter to fix multiple peers connected with the same pub key (#3986) 2025-08-29 19:48:40 +02:00
Pascal Fischer
e14c6de203 [management] fix ephemeral flag on peer batch response (#4420) 2025-08-29 17:41:20 +02:00
Viktor Liu
d4c067f0af [client] Don't deactivate upstream resolvers on failure (#4128) 2025-08-29 17:40:05 +02:00
Pascal Fischer
dbefa8bd9f [management] remove lock and continue user update on failure (#4410) 2025-08-28 17:50:12 +02:00
Pascal Fischer
4fd10b9447 [management] split high latency grpc metrics (#4408) 2025-08-28 13:25:40 +02:00
Viktor Liu
aa595c3073 [client] Fix shared sock buffer allocation (#4409) 2025-08-28 13:25:16 +02:00
119 changed files with 4814 additions and 1347 deletions

View File

@@ -382,6 +382,32 @@ jobs:
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@@ -428,9 +454,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \ CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -521,7 +548,7 @@ jobs:
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/server/http/...
api_integration_test: api_integration_test:
name: "Management / Integration" name: "Management / Integration"
@@ -571,4 +598,4 @@ jobs:
CI=true \ CI=true \
go test -tags=integration \ go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/server/http/...

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -4,6 +4,7 @@ package android
import ( import (
"context" "context"
"os"
"slices" "slices"
"sync" "sync"
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps. // In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() { func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener() c.recorder.RemoveConnectionListener()
} }
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -0,0 +1,32 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
} }
// update host's static platform and system information // update host's static platform and system information
system.UpdateStaticInfo() system.UpdateStaticInfoAsync()
configFilePath, err := activeProf.FilePath() configFilePath, err := activeProf.FilePath()
if err != nil { if err != nil {

View File

@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint log.Info("starting NetBird service") //nolint
// Collect static system and platform information // Collect static system and platform information
system.UpdateStaticInfo() system.UpdateStaticInfoAsync()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer() p.serv = grpc.NewServer()

View File

@@ -9,29 +9,26 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server" sig "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
) )
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
@@ -90,15 +87,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT(). settingsMockManager.EXPECT().

View File

@@ -8,13 +8,14 @@ import (
"runtime" "runtime"
"sync" "sync"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -44,7 +45,7 @@ type ICEBind struct {
RecvChan chan RecvMessage RecvChan chan RecvMessage
transportNet transport.Net transportNet transport.Net
filterFn FilterFn filterFn udpmux.FilterFn
endpoints map[netip.Addr]net.Conn endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
@@ -54,13 +55,13 @@ type ICEBind struct {
closed bool closed bool
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
address wgaddr.Address address wgaddr.Address
mtu uint16 mtu uint16
activityRecorder *ActivityRecorder activityRecorder *ActivityRecorder
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
} }
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind // GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
if s.udpMux == nil { if s.udpMux == nil {
@@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = udpmux.NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn), UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,

View File

@@ -1,7 +0,0 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
} }
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil return time.Unix(sec, 0), nil
} }

View File

@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -29,7 +30,7 @@ type WGTunDevice struct {
name string name string
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -26,7 +27,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -28,7 +29,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -12,8 +12,8 @@ import (
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink link *wgLink
udpMuxConn net.PacketConn udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
filterFn bind.FilterFn filterFn udpmux.FilterFn
} }
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil return configurer, nil
} }
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.udpMux != nil { if t.udpMux != nil {
return t.udpMux, nil return t.udpMux, nil
} }
@@ -106,14 +106,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
udpConn = nbnet.WrapPacketConn(rawSock) udpConn = nbnet.WrapPacketConn(rawSock)
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: udpConn, UDPConn: udpConn,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address, WGAddress: t.address,
MTU: t.mtu, MTU: t.mtu,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock t.udpMuxConn = rawSock
t.udpMux = mux t.udpMux = mux

View File

@@ -10,6 +10,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -26,7 +27,7 @@ type TunNetstackDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
net *netstack.Net net *netstack.Net
@@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -25,7 +26,7 @@ type USPDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -29,7 +30,7 @@ type TunDevice struct {
device *device.Device device *device.Device
nativeTunDevice *tun.NativeTun nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16 MTU uint16
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn udpmux.FilterFn
DisableDNS bool DisableDNS bool
} }
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface // Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before) // The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -1,4 +1,4 @@
package bind package udpmux
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,11 +16,12 @@ import (
) )
type udpMuxedConnParams struct { type udpMuxedConnParams struct {
Mux *UDPMuxDefault Mux *SingleSocketUDPMux
AddrPool *sync.Pool AddrPool *sync.Pool
Key string Key string
LocalAddr net.Addr LocalAddr net.Addr
Logger logging.LeveledLogger Logger logging.LeveledLogger
CandidateID string
} }
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
return err return err
} }
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool { func (c *udpMuxedConn) isClosed() bool {
select { select {
case <-c.closedChan: case <-c.closedChan:

View File

@@ -0,0 +1,64 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,4 +1,4 @@
package bind package udpmux
import ( import (
"fmt" "fmt"
@@ -8,9 +8,9 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192 const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface // SingleSocketUDPMux is an implementation of the interface
type UDPMuxDefault struct { type SingleSocketUDPMux struct {
params UDPMuxParams params Params
closedChan chan struct{} closedChan chan struct{}
closeOnce sync.Once closeOnce sync.Once
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn addressMap map[string][]*udpMuxedConn
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
const maxAddrSize = 512 const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux. // Params are parameters for UDPMux.
type UDPMuxParams struct { type Params struct {
Logger logging.LeveledLogger Logger logging.LeveledLogger
UDPConn net.PacketConn UDPConn net.PacketConn
@@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool {
return true return true
} }
// NewUDPMuxDefault creates an implementation of UDPMux // NewSingleSocketUDPMux creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
if params.Logger == nil { if params.Logger == nil {
params.Logger = getLogger() params.Logger = getLogger()
} }
mux := &UDPMuxDefault{ mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{}, addressMap: map[string][]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn), connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1), candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{ pool: &sync.Pool{
New: func() interface{} { New: func() interface{} {
// big enough buffer to fit both packet and address // big enough buffer to fit both packet and address
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return mux return mux
} }
func (m *UDPMuxDefault) updateLocalAddresses() { func (m *SingleSocketUDPMux) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() { } else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection // it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux. // with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType var networks []ice.NetworkType
switch { switch {
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
m.mu.Unlock() m.mu.Unlock()
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *UDPMuxDefault) LocalAddr() net.Addr { func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr() return m.params.UDPConn.LocalAddr()
} }
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
m.updateLocalAddresses() m.updateLocalAddresses()
m.mu.Lock() m.mu.Lock()
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address // GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
m.mu.Lock() m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified) lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return conn, nil return conn, nil
} }
c := m.createMuxedConn(ufrag) c := m.createMuxedConn(ufrag, candidateID)
go func() { go func() {
<-c.CloseChannel() <-c.CloseChannel()
m.RemoveConnByUfrag(ufrag) m.RemoveConnByUfrag(ufrag)
}() }()
m.candidateConnMap[candidateID] = c
if isIPv6 { if isIPv6 {
m.connsIPv6[ufrag] = c m.connsIPv6[ufrag] = c
} else { } else {
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2) removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock // Keep lock section small to avoid deadlock with conn lock
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok { if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag) delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
if c, ok := m.connsIPv6[ufrag]; ok { if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag) delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
m.mu.Unlock() m.mu.Unlock()
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
} }
// IsClosed returns true if the mux had been closed // IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool { func (m *SingleSocketUDPMux) IsClosed() bool {
select { select {
case <-m.closedChan: case <-m.closedChan:
return true return true
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
} }
// Close the mux, no further connections could be created // Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error { func (m *SingleSocketUDPMux) Close() error {
var err error var err error
m.closeOnce.Do(func() { m.closeOnce.Do(func() {
m.mu.Lock() m.mu.Lock()
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
return err return err
} }
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr) return m.params.UDPConn.WriteTo(buf, rAddr)
} }
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() { if m.IsClosed() {
return return
} }
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{ c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m, Mux: m,
Key: key, Key: key,
AddrPool: m.pool, AddrPool: m.pool,
LocalAddr: m.LocalAddr(), LocalAddr: m.LocalAddr(),
Logger: m.params.Logger, Logger: m.params.Logger,
CandidateID: candidateID,
}) })
return c return c
} }
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr) remoteAddr, ok := addr.(*net.UDPAddr)
if !ok { if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr") return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
} }
// If we have already seen this address dispatch to the appropriate destination // Try to route to specific candidate connection first
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one if conn := m.findCandidateConnection(msg); conn != nil {
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. return conn.writePacket(msg.Raw, remoteAddr)
// We will then forward STUN packets to each of these connections. }
m.addressMapMu.RLock()
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
// Add connections from address map
m.addressMapMu.RLock()
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.RUnlock() m.addressMapMu.RUnlock()
var isIPv6 bool if conn, ok := m.findConnectionByUsername(msg, addr); ok {
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { // If we have already seen this address dispatch to the appropriate destination
isIPv6 = true // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
} }
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. // Forward to all found connections
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList { for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err) log.Errorf("could not write packet: %v", err)
} }
} }
return nil return nil
} }
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { // findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 { if isIPv6 {
val, ok = m.connsIPv6[ufrag] val, ok = m.connsIPv6[ufrag]
} else { } else {
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
return return
} }
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct { type bufferHolder struct {
buf []byte buf []byte
} }

View File

@@ -1,12 +1,12 @@
//go:build !ios //go:build !ios
package bind package udpmux
import ( import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr) conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package bind package udpmux
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -15,7 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing. // It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct { type UniversalUDPMuxDefault struct {
*UDPMuxDefault *SingleSocketUDPMux
params UniversalUDPMuxParams params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
udpMuxParams := UDPMuxParams{ udpMuxParams := Params{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
Net: m.params.Net, Net: m.params.Net,
} }
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
return m return m
} }
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server. // and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
} }
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
} }
return nil return nil
} }
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
} }
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -275,7 +275,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil { if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
@@ -284,10 +284,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil {
select { close(runningChan)
case runningChan <- struct{}{}: runningChan = nil
default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()

View File

@@ -0,0 +1,201 @@
package config
import (
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
ErrEmptyURL = errors.New("empty URL")
ErrEmptyHost = errors.New("empty host")
ErrIPNotAllowed = errors.New("IP address not allowed")
)
// ServerDomains represents the management server domains extracted from NetBird configuration
type ServerDomains struct {
Signal domain.Domain
Relay []domain.Domain
Flow domain.Domain
Stuns []domain.Domain
Turns []domain.Domain
}
// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
if config == nil {
return ServerDomains{}
}
domains := ServerDomains{}
domains.Signal = extractSignalDomain(config)
domains.Relay = extractRelayDomains(config)
domains.Flow = extractFlowDomain(config)
domains.Stuns = extractStunDomains(config)
domains.Turns = extractTurnDomains(config)
return domains
}
// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func ExtractValidDomain(rawURL string) (domain.Domain, error) {
if rawURL == "" {
return "", ErrEmptyURL
}
parsedURL, err := url.Parse(rawURL)
if err == nil {
if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" {
return domain, err
}
}
return extractFromRawString(rawURL)
}
// extractFromParsedURL handles domain extraction from successfully parsed URLs
func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) {
if parsedURL.Hostname() != "" {
return extractDomainFromHost(parsedURL.Hostname())
}
if parsedURL.Opaque == "" || parsedURL.Scheme == "" {
return "", nil
}
// Handle URLs with opaque content (e.g., stun:host:port)
if strings.Contains(parsedURL.Scheme, ".") {
// This is likely "domain.com:port" being parsed as scheme:opaque
reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque
if host, _, err := net.SplitHostPort(reconstructed); err == nil {
return extractDomainFromHost(host)
}
return extractDomainFromHost(parsedURL.Scheme)
}
// Valid scheme with opaque content (e.g., stun:host:port)
host := parsedURL.Opaque
if queryIndex := strings.Index(host, "?"); queryIndex > 0 {
host = host[:queryIndex]
}
if hostOnly, _, err := net.SplitHostPort(host); err == nil {
return extractDomainFromHost(hostOnly)
}
return extractDomainFromHost(host)
}
// extractFromRawString handles domain extraction when URL parsing fails or returns no results
func extractFromRawString(rawURL string) (domain.Domain, error) {
if host, _, err := net.SplitHostPort(rawURL); err == nil {
return extractDomainFromHost(host)
}
return extractDomainFromHost(rawURL)
}
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
func extractDomainFromHost(host string) (domain.Domain, error) {
if host == "" {
return "", ErrEmptyHost
}
if _, err := netip.ParseAddr(host); err == nil {
return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host)
}
d, err := domain.FromString(host)
if err != nil {
return "", fmt.Errorf("invalid domain: %v", err)
}
return d, nil
}
// extractSingleDomain extracts a single domain from a URL with error logging
func extractSingleDomain(url, serviceType string) domain.Domain {
if url == "" {
return ""
}
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
return ""
}
return d
}
// extractMultipleDomains extracts multiple domains from URLs with error logging
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
var domains []domain.Domain
for _, url := range urls {
if url == "" {
continue
}
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
continue
}
domains = append(domains, d)
}
return domains
}
// extractSignalDomain extracts the signal domain from NetBird configuration.
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Signal != nil {
return extractSingleDomain(config.Signal.Uri, "signal")
}
return ""
}
// extractRelayDomains extracts relay server domains from NetBird configuration.
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay != nil {
return extractMultipleDomains(config.Relay.Urls, "relay")
}
return nil
}
// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Flow != nil {
return extractSingleDomain(config.Flow.Url, "flow")
}
return ""
}
// extractStunDomains extracts STUN server domains from NetBird configuration.
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
urls = append(urls, stun.Uri)
}
}
return extractMultipleDomains(urls, "STUN")
}
// extractTurnDomains extracts TURN server domains from NetBird configuration.
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
urls = append(urls, turn.HostConfig.Uri)
}
}
return extractMultipleDomains(urls, "TURN")
}

View File

@@ -0,0 +1,213 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtractValidDomain(t *testing.T) {
tests := []struct {
name string
url string
expected string
expectError bool
}{
{
name: "HTTPS URL with port",
url: "https://api.netbird.io:443",
expected: "api.netbird.io",
},
{
name: "HTTP URL without port",
url: "http://signal.example.com",
expected: "signal.example.com",
},
{
name: "Host with port (no scheme)",
url: "signal.netbird.io:443",
expected: "signal.netbird.io",
},
{
name: "STUN URL",
url: "stun:stun.netbird.io:443",
expected: "stun.netbird.io",
},
{
name: "STUN URL with different port",
url: "stun:stun.netbird.io:5555",
expected: "stun.netbird.io",
},
{
name: "TURNS URL with query params",
url: "turns:turn.netbird.io:443?transport=tcp",
expected: "turn.netbird.io",
},
{
name: "TURN URL",
url: "turn:turn.example.com:3478",
expected: "turn.example.com",
},
{
name: "REL URL",
url: "rel://relay.example.com:443",
expected: "relay.example.com",
},
{
name: "RELS URL",
url: "rels://relay.netbird.io:443",
expected: "relay.netbird.io",
},
{
name: "Raw hostname",
url: "example.org",
expected: "example.org",
},
{
name: "IP address should be rejected",
url: "192.168.1.1",
expectError: true,
},
{
name: "IP address with port should be rejected",
url: "192.168.1.1:443",
expectError: true,
},
{
name: "IPv6 address should be rejected",
url: "2001:db8::1",
expectError: true,
},
{
name: "HTTP URL with IPv4 should be rejected",
url: "http://192.168.1.1:8080",
expectError: true,
},
{
name: "HTTPS URL with IPv4 should be rejected",
url: "https://10.0.0.1:443",
expectError: true,
},
{
name: "STUN URL with IPv4 should be rejected",
url: "stun:192.168.1.1:3478",
expectError: true,
},
{
name: "TURN URL with IPv4 should be rejected",
url: "turn:10.0.0.1:3478",
expectError: true,
},
{
name: "TURNS URL with IPv4 should be rejected",
url: "turns:172.16.0.1:5349",
expectError: true,
},
{
name: "HTTP URL with IPv6 should be rejected",
url: "http://[2001:db8::1]:8080",
expectError: true,
},
{
name: "HTTPS URL with IPv6 should be rejected",
url: "https://[::1]:443",
expectError: true,
},
{
name: "STUN URL with IPv6 should be rejected",
url: "stun:[2001:db8::1]:3478",
expectError: true,
},
{
name: "IPv6 with port should be rejected",
url: "[2001:db8::1]:443",
expectError: true,
},
{
name: "Localhost IPv4 should be rejected",
url: "127.0.0.1:8080",
expectError: true,
},
{
name: "Localhost IPv6 should be rejected",
url: "[::1]:443",
expectError: true,
},
{
name: "REL URL with IPv4 should be rejected",
url: "rel://192.168.1.1:443",
expectError: true,
},
{
name: "RELS URL with IPv4 should be rejected",
url: "rels://10.0.0.1:443",
expectError: true,
},
{
name: "Empty URL",
url: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := ExtractValidDomain(tt.url)
if tt.expectError {
assert.Error(t, err, "Expected error for URL: %s", tt.url)
} else {
assert.NoError(t, err, "Unexpected error for URL: %s", tt.url)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url)
}
})
}
}
func TestExtractDomainFromHost(t *testing.T) {
tests := []struct {
name string
host string
expected string
expectError bool
}{
{
name: "Valid domain",
host: "example.com",
expected: "example.com",
},
{
name: "Subdomain",
host: "api.example.com",
expected: "api.example.com",
},
{
name: "IPv4 address",
host: "192.168.1.1",
expectError: true,
},
{
name: "IPv6 address",
host: "2001:db8::1",
expectError: true,
},
{
name: "Empty host",
host: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractDomainFromHost(tt.host)
if tt.expectError {
assert.Error(t, err, "Expected error for host: %s", tt.host)
} else {
assert.NoError(t, err, "Unexpected error for host: %s", tt.host)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host)
}
})
}
}

View File

@@ -11,11 +11,12 @@ import (
) )
const ( const (
PriorityLocal = 100 PriorityMgmtCache = 150
PriorityDNSRoute = 75 PriorityLocal = 100
PriorityUpstream = 50 PriorityDNSRoute = 75
PriorityDefault = 1 PriorityUpstream = 50
PriorityFallback = -100 PriorityDefault = 1
PriorityFallback = -100
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {
@@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler // If handler wants to continue, try next handler
if chainWriter.shouldContinue { if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler for domain=%s", qname) // Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue continue
} }
return return

View File

@@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool {
// String returns a string representation of the local resolver // String returns a string representation of the local resolver
func (d *Resolver) String() string { func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records)) return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
} }
func (d *Resolver) Stop() {} func (d *Resolver) Stop() {}

View File

@@ -0,0 +1,360 @@
package mgmt
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
const dnsTimeout = 5 * time.Second
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
// String returns a string representation of the resolver.
func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
m.continueToNext(w, r)
return
}
m.mutex.RLock()
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
m.continueToNext(w, r)
return
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write response: %v", err)
}
}
// MatchSubdomains returns false since this resolver only handles exact domain matches
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
func (m *Resolver) MatchSubdomains() bool {
return false
}
// continueToNext signals the handler chain to continue to the next handler.
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write continue signal: %v", err)
}
}
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
}
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}
m.mutex.Lock()
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {
return nil
}
d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
if err != nil {
return fmt.Errorf("extract domain from URL: %w", err)
}
m.mutex.Lock()
m.mgmtDomain = &d
m.mutex.Unlock()
if err := m.AddDomain(ctx, d); err != nil {
return fmt.Errorf("add domain: %w", err)
}
return nil
}
// RemoveDomain removes a domain from the cache.
func (m *Resolver) RemoveDomain(d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.Lock()
defer m.mutex.Unlock()
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
}
// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
defer m.mutex.RUnlock()
domainSet := make(map[domain.Domain]struct{})
for question := range m.records {
domainName := strings.TrimSuffix(question.Name, ".")
domainSet[domain.Domain(domainName)] = struct{}{}
}
domains := make(domain.List, 0, len(domainSet))
for d := range domainSet {
domains = append(domains, d)
}
return domains
}
// UpdateFromServerDomains updates the cache with server domains from network configuration.
// It merges new domains with existing ones, replacing entire domain types when updated.
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains domain.List
if len(newDomains) > 0 {
m.mutex.Lock()
if m.serverDomains == nil {
m.serverDomains = &dnsconfig.ServerDomains{}
}
updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains)
m.serverDomains = &updatedServerDomains
m.mutex.Unlock()
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
}
m.addNewDomains(ctx, newDomains)
return removedDomains, nil
}
// removeStaleDomains removes cached domains not present in the target domain list.
// Management domains are preserved and never removed during server domain updates.
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
var removedDomains domain.List
for _, currentDomain := range currentDomains {
if m.isDomainInList(currentDomain, newDomains) {
continue
}
if m.isManagementDomain(currentDomain) {
continue
}
removedDomains = append(removedDomains, currentDomain)
if err := m.RemoveDomain(currentDomain); err != nil {
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
}
}
return removedDomains
}
// mergeServerDomains merges new server domains with existing ones.
// When a domain type is provided in the new domains, it completely replaces that type.
func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains {
merged := existing
if incoming.Signal != "" {
merged.Signal = incoming.Signal
}
if len(incoming.Relay) > 0 {
merged.Relay = incoming.Relay
}
if incoming.Flow != "" {
merged.Flow = incoming.Flow
}
if len(incoming.Stuns) > 0 {
merged.Stuns = incoming.Stuns
}
if len(incoming.Turns) > 0 {
merged.Turns = incoming.Turns
}
return merged
}
// isDomainInList checks if domain exists in the list
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
for _, d := range list {
if domain.SafeString() == d.SafeString() {
return true
}
}
return false
}
// isManagementDomain checks if domain is the protected management domain
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains domain.List
if serverDomains.Signal != "" {
domains = append(domains, serverDomains.Signal)
}
for _, relay := range serverDomains.Relay {
if relay != "" {
domains = append(domains, relay)
}
}
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
for _, stun := range serverDomains.Stuns {
if stun != "" {
domains = append(domains, stun)
}
}
for _, turn := range serverDomains.Turns {
if turn != "" {
domains = append(domains, turn)
}
}
return domains
}

View File

@@ -0,0 +1,416 @@
package mgmt
import (
"context"
"fmt"
"net/url"
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver()
assert.NotNil(t, resolver)
assert.NotNil(t, resolver.records)
assert.False(t, resolver.MatchSubdomains())
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string
urlStr string
expectedDom string
expectError bool
}{
{
name: "HTTPS URL with port",
urlStr: "https://api.netbird.io:443",
expectedDom: "api.netbird.io",
expectError: false,
},
{
name: "HTTP URL without port",
urlStr: "http://signal.example.com",
expectedDom: "signal.example.com",
expectError: false,
},
{
name: "URL with path",
urlStr: "https://relay.netbird.io/status",
expectedDom: "relay.netbird.io",
expectError: false,
},
{
name: "Invalid URL",
urlStr: "not-a-valid-url",
expectedDom: "not-a-valid-url",
expectError: false,
},
{
name: "Empty URL",
urlStr: "",
expectedDom: "",
expectError: true,
},
{
name: "STUN URL",
urlStr: "stun:stun.example.com:3478",
expectedDom: "stun.example.com",
expectError: false,
},
{
name: "TURN URL",
urlStr: "turn:turn.example.com:3478",
expectedDom: "turn.example.com",
expectError: false,
},
{
name: "REL URL",
urlStr: "rel://relay.example.com:443",
expectedDom: "relay.example.com",
expectError: false,
},
{
name: "RELS URL",
urlStr: "rels://relay.example.com:443",
expectedDom: "relay.example.com",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var parsedURL *url.URL
var err error
if tt.urlStr != "" {
parsedURL, err = url.Parse(tt.urlStr)
if err != nil && !tt.expectError {
t.Fatalf("Failed to parse URL: %v", err)
}
}
domain, err := extractDomainFromURL(parsedURL)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedDom, domain.SafeString())
}
})
}
}
func TestResolver_PopulateFromConfig(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := NewResolver()
// Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.Error(t, err)
assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
// No domains should be cached when using IP addresses
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_ServeDNS(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Add a test domain to the cache - use example.org which is reserved for testing
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Test A record query for cached domain
t.Run("Cached domain A record", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
})
// Test uncached domain signals to continue to next handler
t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
req := new(dns.Msg)
req.SetQuestion("unknown.example.com.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
// Zero flag set to true signals the handler chain to continue to next handler
assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
})
// Test that subdomains of cached domains are NOT resolved
t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query for a subdomain of our cached domain
req := new(dns.Msg)
req.SetQuestion("sub.example.org.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
})
// Test case-insensitive matching
t.Run("Case-insensitive domain matching", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query with different casing
req := new(dns.Msg)
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
})
}
func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
cachedDomains := resolver.GetCachedDomains()
assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
}
func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
mgmtURL, _ := url.Parse("https://example.org")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
initialDomains := resolver.GetCachedDomains()
if len(initialDomains) == 0 {
t.Skip("Management domain failed to resolve, skipping test")
}
assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
assert.Equal(t, "example.org", initialDomains[0].SafeString())
serverDomains := dnsconfig.ServerDomains{
Signal: "google.com",
Relay: []domain.Domain{"cloudflare.com"},
}
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
if err != nil {
t.Logf("Server domains update failed: %v", err)
}
finalDomains := resolver.GetCachedDomains()
managementStillCached := false
for _, d := range finalDomains {
if d.SafeString() == "example.org" {
managementStillCached = true
break
}
}
assert.True(t, managementStillCached, "Management domain should never be removed")
}
// extractDomainFromURL extracts a domain from a URL - test helper function
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {
return "", fmt.Errorf("URL is nil")
}
return dnsconfig.ExtractValidDomain(u.String())
}
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Verify domains were added
cachedDomains := resolver.GetCachedDomains()
assert.Len(t, cachedDomains, 3)
// Update with empty ServerDomains (simulating partial network map update)
emptyDomains := dnsconfig.ServerDomains{}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
assert.NoError(t, err)
// Verify no domains were removed
assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty")
// Verify all original domains are still cached
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "All original domains should still be cached")
}
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial complete domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
partialDomains := dnsconfig.ServerDomains{
Signal: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Should remove only the old signal domain
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
assert.Equal(t, "example.org", removedDomains[0].SafeString())
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
domainStrings[i] = d.SafeString()
}
assert.Contains(t, domainStrings, "github.com")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.NotContains(t, domainStrings, "example.org")
}
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial complete domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
partialDomains := dnsconfig.ServerDomains{
Flow: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
domainStrings[i] = d.SafeString()
}
assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.Contains(t, domainStrings, "github.com")
}

View File

@@ -3,20 +3,23 @@ package dns
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"github.com/miekg/dns" "github.com/miekg/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
// MockServer is the mock instance of a dns server // MockServer is the mock instance of a dns server
type MockServer struct { type MockServer struct {
InitializeFunc func() error InitializeFunc func() error
StopFunc func() StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func(domain.List, dns.Handler, int) RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int) DeregisterHandlerFunc func(domain.List, int)
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
} }
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string {
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() { func (m *MockServer) ProbeAvailability() {
} }
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
}
return nil
}
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@@ -15,7 +16,9 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -45,6 +48,8 @@ type Server interface {
OnUpdatedHostDNSServer(addrs []netip.AddrPort) OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
} }
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
@@ -77,6 +82,8 @@ type DefaultServer struct {
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int extraDomains map[domain.Domain]int
mgmtCacheResolver *mgmt.Resolver
// permanent related properties // permanent related properties
permanent bool permanent bool
hostsDNSHolder *hostsDNSHolder hostsDNSHolder *hostsDNSHolder
@@ -104,18 +111,20 @@ type handlerWrapper struct {
type registeredHandlerMap map[types.HandlerID]handlerWrapper type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
CustomAddress string
StatusRecorder *peer.Status
StateManager *statemanager.Manager
DisableSys bool
}
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) {
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
} }
@@ -123,13 +132,14 @@ func NewDefaultServer(
} }
var dnsService service var dnsService service
if wgInterface.IsUserspaceBind() { if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface) dnsService = NewServiceViaMemory(config.WgInterface)
} else { } else {
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(config.WgInterface, addrPort)
} }
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
return server, nil
} }
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -178,20 +188,24 @@ func newDefaultServer(
) *DefaultServer { ) *DefaultServer {
handlerChain := NewHandlerChain() handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
disableSys: disableSys, disableSys: disableSys,
service: dnsService, service: dnsService,
handlerChain: handlerChain, handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(), localResolver: local.NewResolver(),
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager, stateManager: stateManager,
hostsDNSHolder: newHostsDNSHolder(), hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
} }
// register with root zone, handler chain takes care of the routing // register with root zone, handler chain takes care of the routing
@@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority) log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains)
for _, domain := range domains { for _, domain := range domains {
if domain == "" { if domain == "" {
@@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
} }
func (s *DefaultServer) deregisterHandler(domains []string, priority int) { func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority) log.Debugf("deregistering handler with priority %d for %v", priority, domains)
for _, domain := range domains { for _, domain := range domains {
if domain == "" { if domain == "" {
@@ -432,6 +446,29 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Wait() wg.Wait()
} }
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
s.mux.Lock()
defer s.mux.Unlock()
if s.mgmtCacheResolver != nil {
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
if err != nil {
return fmt.Errorf("update management cache resolver: %w", err)
}
if len(removedDomains) > 0 {
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
if len(newDomains) > 0 {
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
}
}
return nil
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver // is the service should be Disabled, we stop the listener or fake resolver
if update.ServiceEnable { if update.ServiceEnable {
@@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain {
), ),
) )
} }
// PopulateManagementDomain populates the DNS cache with management domain
func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
if s.mgmtCacheResolver != nil {
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
return nil
}

View File

@@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }

View File

@@ -33,9 +33,11 @@ func SetCurrentMTU(mtu uint16) {
} }
const ( const (
UpstreamTimeout = 15 * time.Second UpstreamTimeout = 4 * time.Second
// ClientTimeout is the timeout for the dns.Client.
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
failsTillDeact = int32(5)
reactivatePeriod = 30 * time.Second reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
) )
@@ -58,9 +60,7 @@ type upstreamResolverBase struct {
upstreamServers []netip.AddrPort upstreamServers []netip.AddrPort
domain string domain string
disabled bool disabled bool
failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex mutex sync.Mutex
reactivatePeriod time.Duration reactivatePeriod time.Duration
upstreamTimeout time.Duration upstreamTimeout time.Duration
@@ -79,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
domain: domain, domain: domain,
upstreamTimeout: UpstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
} }
} }
// String returns a string representation of the upstream resolver // String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string { func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %s", u.upstreamServers) return fmt.Sprintf("Upstream %s", u.upstreamServers)
} }
// ID returns the unique handler ID // ID returns the unique handler ID
@@ -116,58 +115,102 @@ func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID() requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID) logger := log.WithField("request_id", requestID)
var err error
defer func() {
u.checkUpstreamFails(err)
}()
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
if u.ctx.Err() != nil {
logger.Tracef("%s has been stopped", u)
return
}
if u.tryUpstreamServers(w, r, logger) {
return
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
}
select { func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
case <-u.ctx.Done(): timeout := u.upstreamTimeout
logger.Tracef("%s has been stopped", u) if len(u.upstreamServers) > 1 {
return maxTotal := 5 * time.Second
default: minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
var rm *dns.Msg if u.queryUpstream(w, r, upstream, timeout, logger) {
var t time.Duration return true
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
} }
}
return false
}
if rm == nil || !rm.Response { func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) var rm *dns.Msg
continue var t time.Duration
} var err error
u.successCount.Add(1) var startTime time.Time
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err = w.WriteMsg(rm); err != nil { if err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
} return false
// count the fails only if they happen sequentially }
u.failsCount.Store(0)
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
}
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return return
} }
u.failsCount.Add(1)
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
@@ -177,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// checkUpstreamFails counts fails and disables or enables upstream resolving
//
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
u.mutex.Lock()
defer u.mutex.Unlock()
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
return
}
select {
case <-u.ctx.Done():
return
default:
}
u.disable(err)
if u.statusRecorder == nil {
return
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta
)
}
// ProbeAvailability tests all upstream servers simultaneously and // ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work // disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() { func (u *upstreamResolverBase) ProbeAvailability() {
@@ -224,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() {
default: default:
} }
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact // avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { if u.successCount.Load() > 0 {
return return
} }
@@ -312,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
u.disabled = false u.disabled = false
@@ -416,3 +423,80 @@ func GenerateRequestID() string {
} }
return hex.EncodeToString(bytes) return hex.EncodeToString(bytes)
} }
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
switch {
case !isConnected:
statusInfo += " DISCONNECTED"
case !hasRecentHandshake:
statusInfo += " NO_RECENT_HANDSHAKE"
default:
statusInfo += " connected"
}
if !peerState.LastWireguardHandshake.IsZero() {
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
} else {
statusInfo += " no_handshake"
}
if peerState.Relayed {
statusInfo += " via_relay"
}
if peerState.Latency > 0 {
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
}
return statusInfo
}
// findPeerForIP finds which peer handles the given IP address
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
if statusRecorder == nil {
return nil
}
fullStatus := statusRecorder.GetFullStatus()
var bestMatch *peer.State
var bestPrefixLen int
for _, peerState := range fullStatus.Peers {
routes := peerState.GetRoutes()
for route := range routes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
continue
}
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
peerStateCopy := peerState
bestMatch = &peerStateCopy
bestPrefixLen = prefix.Bits()
}
}
}
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}

View File

@@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
} }
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{} upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }
@@ -72,10 +74,11 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
} }
upstreamExchangeClient := &dns.Client{ upstreamExchangeClient := &dns.Client{
Dialer: dialer, Dialer: dialer,
Timeout: timeout,
} }
return upstreamExchangeClient.Exchange(r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }
func (u *upstreamResolver) isLocalResolver(upstream string) bool { func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -34,7 +34,10 @@ func newUpstreamResolver(
} }
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) client := &dns.Client{
Timeout: ClientTimeout,
}
return ExchangeWithFallback(ctx, client, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {

View File

@@ -47,7 +47,9 @@ func newUpstreamResolver(
} }
func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
client := &dns.Client{} client := &dns.Client{
Timeout: ClientTimeout,
}
upstreamHost, _, err := net.SplitHostPort(upstream) upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
@@ -110,7 +112,8 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura
}, },
} }
client := &dns.Client{ client := &dns.Client{
Dialer: dialer, Dialer: dialer,
Timeout: dialTimeout,
} }
return client, nil return client, nil
} }

View File

@@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
} }
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
r: new(dns.Msg),
rtt: time.Millisecond,
}
resolver := &upstreamResolverBase{ resolver := &upstreamResolverBase{
ctx: context.TODO(), ctx: context.TODO(),
upstreamClient: &mockUpstreamResolver{ upstreamClient: mockClient,
err: nil,
r: new(dns.Msg),
rtt: time.Millisecond,
},
upstreamTimeout: UpstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: time.Microsecond * 100,
failsTillDeact: failsTillDeact,
} }
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil },
}
failed := false failed := false
resolver.deactivate = func(error) { resolver.deactivate = func(error) {
failed = true failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
} }
reactivated := false reactivated := false
@@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true reactivated = true
} }
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) resolver.ProbeAvailability()
if !failed { if !failed {
t.Errorf("expected that resolving was deactivated") t.Errorf("expected that resolving was deactivated")
@@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
return return
} }
if resolver.failsCount.Load() != 0 {
t.Errorf("fails count after reactivation should be 0")
return
}
if resolver.disabled { if resolver.disabled {
t.Errorf("should be enabled") t.Errorf("should be enabled")
} }

View File

@@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
"net/url"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
@@ -17,8 +18,8 @@ import (
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -28,11 +29,12 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -164,7 +166,7 @@ type Engine struct {
wgInterface WGIface wgInterface WGIface
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
@@ -345,7 +347,7 @@ func (e *Engine) Stop() error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error { func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -401,6 +403,11 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
// Populate DNS cache with NetbirdConfig and management URL for early resolution
if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache: %v", err)
}
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: e.ctx, Context: e.ctx,
PublicKey: e.config.WgPrivateKey.PublicKey().String(), PublicKey: e.config.WgPrivateKey.PublicKey().String(),
@@ -454,7 +461,7 @@ func (e *Engine) Start() error {
StunTurn: &e.stunTurn, StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault, UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux, UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
} }
@@ -661,6 +668,30 @@ func (e *Engine) removePeer(peerKey string) error {
return nil return nil
} }
// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response
func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
if e.dnsServer == nil {
return nil
}
// Populate management URL if provided
if mgmtURL != nil {
if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
}
// Populate NetbirdConfig domains if provided
if netbirdConfig != nil {
serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err)
}
}
return nil
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -692,6 +723,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err) return fmt.Errorf("handle the flow configuration: %w", err)
} }
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal // todo update signal
} }
@@ -914,7 +949,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.LazyConnectionEnabled, e.config.LazyConnectionEnabled,
) )
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
@@ -925,7 +959,7 @@ func (e *Engine) receiveManagementEvents() {
} }
log.Debugf("stopped receiving updates from Management Service") log.Debugf("stopped receiving updates from Management Service")
}() }()
log.Debugf("connecting to Management Service updates stream") log.Infof("connecting to Management Service updates stream")
} }
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
@@ -1292,7 +1326,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
StunTurn: &e.stunTurn, StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault, UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux, UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
}, },
@@ -1557,7 +1591,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil return dnsServer, nil
default: default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,
DisableSys: e.config.DisableDNS,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -19,21 +19,18 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
@@ -45,9 +42,12 @@ import (
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -85,7 +85,7 @@ type MockWGIface struct {
NameFunc func() string NameFunc func() string
AddressFunc func() wgaddr.Address AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error RemovePeerFunc func(peerKey string) error
@@ -135,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc() return m.ToInterfaceFunc()
} }
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return m.UpFunc() return m.UpFunc()
} }
@@ -266,7 +266,7 @@ func TestEngine_SSH(t *testing.T) {
}, },
}, nil }, nil
} }
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -414,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
@@ -612,7 +612,7 @@ func TestEngine_Sync(t *testing.T) {
} }
}() }()
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1069,7 +1069,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
defer mu.Unlock() defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String()) guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid) device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err) t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done() wg.Done()
@@ -1555,7 +1555,11 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
@@ -1572,7 +1576,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
Return(&types.ExtraSettings{}, nil). Return(&types.ExtraSettings{}, nil).
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)

View File

@@ -9,9 +9,9 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error

View File

@@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool,
return false, err return false, err
} }
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) { if isLoginNeeded(err) {
return true, nil return true, nil
} }
@@ -69,14 +69,18 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string,
return err return err
} }
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) { if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required") log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err return err
} }
return err return nil
} }
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
@@ -101,11 +105,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
return mgmClient, err return mgmClient, err
} }
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey() serverKey, err := mgmClient.GetServerPublicKey()
if err != nil { if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err) log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err return nil, nil, err
} }
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
@@ -121,8 +125,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.BlockInbound, config.BlockInbound,
config.LazyConnectionEnabled, config.LazyConnectionEnabled,
) )
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, err return serverKey, loginResp, err
} }
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.

View File

@@ -6,12 +6,11 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
"os"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" { if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
} }

View File

@@ -0,0 +1,14 @@
package peer
import (
"os"
"strings"
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func isForceRelayed() bool {
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}

View File

@@ -6,7 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"

View File

@@ -43,13 +43,6 @@ type OfferAnswer struct {
SessionID *ICESessionID SessionID *ICESessionID
} }
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
}
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
@@ -57,7 +50,7 @@ type Handshaker struct {
signaler *Signaler signaler *Signaler
ice *WorkerICE ice *WorkerICE
relay *WorkerRelay relay *WorkerRelay
onNewOfferListeners []func(*OfferAnswer) onNewOfferListeners []*OfferListener
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
@@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
} }
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.onNewOfferListeners = append(h.onNewOfferListeners, offer) l := NewOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
} }
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
@@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue continue
} }
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer) listener.Notify(&remoteOfferAnswer)
} }
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer) listener.Notify(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers") h.log.Infof("stop listening for remote offers and answers")

View File

@@ -0,0 +1,62 @@
package peer
import (
"sync"
)
type callbackFunc func(remoteOfferAnswer *OfferAnswer)
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
}
type OfferListener struct {
fn callbackFunc
running bool
latest *OfferAnswer
mu sync.Mutex
}
func NewOfferListener(fn callbackFunc) *OfferListener {
return &OfferListener{
fn: fn,
}
}
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock()
defer o.mu.Unlock()
// Store the latest offer
o.latest = remoteOfferAnswer
// If already running, the running goroutine will pick up this latest value
if o.running {
return
}
// Start processing
o.running = true
// Process in a goroutine to avoid blocking the caller
go func(remoteOfferAnswer *OfferAnswer) {
for {
o.fn(remoteOfferAnswer)
o.mu.Lock()
if o.latest == nil {
// No more work to do
o.running = false
o.mu.Unlock()
return
}
remoteOfferAnswer = o.latest
// Clear the latest to mark it as being processed
o.latest = nil
o.mu.Unlock()
}
}(remoteOfferAnswer)
}

View File

@@ -0,0 +1,39 @@
package peer
import (
"testing"
"time"
)
func Test_newOfferListener(t *testing.T) {
dummyOfferAnswer := &OfferAnswer{}
runChan := make(chan struct{}, 10)
longRunningFn := func(remoteOfferAnswer *OfferAnswer) {
time.Sleep(1 * time.Second)
runChan <- struct{}{}
}
hl := NewOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)
// Wait for exactly 2 callbacks
for i := 0; i < 2; i++ {
select {
case <-runChan:
case <-time.After(3 * time.Second):
t.Fatal("Timeout waiting for callback")
}
}
// Verify no additional callbacks happen
select {
case <-runChan:
t.Fatal("Unexpected additional callback")
case <-time.After(100 * time.Millisecond):
t.Log("Correctly received exactly 2 callbacks")
}
}

View File

@@ -3,7 +3,7 @@ package ice
import ( import (
"sync/atomic" "sync/atomic"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
) )
type StunTurn atomic.Value type StunTurn atomic.Value

View File

@@ -4,7 +4,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/randutil" "github.com/pion/randutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"

View File

@@ -1,7 +1,7 @@
package ice package ice
import ( import (
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
) )
type Config struct { type Config struct {

View File

@@ -1,7 +1,7 @@
package peer package peer
import ( import (
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"

View File

@@ -30,9 +30,10 @@ type WGWatcher struct {
peerKey string peerKey string
stateDump *stateDump stateDump *stateDump
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxLock sync.Mutex ctxLock sync.Mutex
enabledTime time.Time
} }
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher") w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock() w.ctxLock.Lock()
w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil { if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled") w.log.Errorf("WireGuard watcher already enabled")
@@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
onDisconnectedFn() onDisconnectedFn()
return return
} }
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
lastHandshake = *handshake lastHandshake = *handshake
resetTime := time.Until(handshake.Add(checkPeriod)) resetTime := time.Until(handshake.Add(checkPeriod))

View File

@@ -8,12 +8,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype" "github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -55,10 +54,6 @@ type WorkerICE struct {
sessionID ICESessionID sessionID ICESessionID
muxAgent sync.Mutex muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string localUfrag string
localPwd string localPwd string
@@ -122,7 +117,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.agent = nil w.agent = nil
// todo consider to switch to Relay connection while establishing a new ICE connection
} }
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
@@ -140,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
w.sentExtraSrflx = false
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel w.agentDialerCancel = dialerCancel
w.agentConnecting = true w.agentConnecting = true
@@ -167,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
w.log.Errorf("error while handling remote candidate") w.log.Errorf("error while handling remote candidate")
return return
} }
if shouldAddExtraCandidate(candidate) {
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
} }
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
@@ -214,10 +222,6 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
return nil, err return nil, err
} }
if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil return agent, nil
} }
@@ -253,6 +257,11 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.closeAgent(agent, w.agentDialerCancel) w.closeAgent(agent, w.agentDialerCancel)
return return
} }
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.closeAgent(agent, w.agentDialerCancel)
return
}
if !isRelayCandidate(pair.Local) { if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one // dynamically set remote WireGuard port if other side specified a different one from the default one
@@ -327,7 +336,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
return return
} }
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
if !ok { if !ok {
w.log.Warn("invalid udp mux conversion") w.log.Warn("invalid udp mux conversion")
return return
@@ -354,31 +363,32 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
} }
}() }()
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
} }
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key) w.config.Key)
w.muxAgent.Lock()
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
w.log.Warnf("failed to get selected candidate pair: %s", err)
w.muxAgent.Unlock()
return
}
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.muxAgent.Unlock()
return
}
w.muxAgent.Unlock()
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
} }
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
@@ -388,7 +398,10 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
case ice.ConnectionStateConnected: case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected w.lastKnownState = ice.ConnectionStateConnected
return return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected()
@@ -400,32 +413,34 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
} }
} }
func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) {
if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key if isController(w.config) {
if isControlling { return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else { } else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} }
} }
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
if candidate.Type() != ice.CandidateTypeServerReflexive {
return false
}
if candidate.Port() == candidate.RelatedAddress().Port {
return false
}
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
// in newer version we generate locally the extra candidate
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
return false
}
return true
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress() relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(), Network: candidate.NetworkType().String(),
Address: candidate.Address(), Address: candidate.Address(),
Port: relatedAdd.Port, Port: relatedAdd.Port,
@@ -433,6 +448,21 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
RelAddr: relatedAdd.Address, RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port, RelPort: relatedAdd.Port,
}) })
if err != nil {
return nil, err
}
for _, e := range candidate.Extensions() {
// overwrite the original candidate ID with the new one to avoid candidate duplication
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = candidate.ID()
}
if err := ec.AddExtension(e); err != nil {
return nil, err
}
}
return ec, nil
} }
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {

View File

@@ -7,7 +7,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/turn/v3" "github.com/pion/turn/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"

View File

@@ -2,11 +2,13 @@ package dnsinterceptor
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -26,6 +28,8 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const dnsTimeout = 8 * time.Second
type domainMap map[domain.Domain][]netip.Prefix type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface { type internalDNATer interface {
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil { if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return return
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil { if err != nil {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
} }
return return
} }
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""
}
peerState, err := d.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
}
return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
}

View File

@@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
if netstack.IsEnabled() { if netstack.IsEnabled() {
n.iFaceDiscover = pionDiscover{} n.iFaceDiscover = pionDiscover{}
} else { } else {
newMobileIFaceDiscover(iFaceDiscover) n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover)
} }
return n, n.UpdateInterfaces() return n, n.UpdateInterfaces()
} }

View File

@@ -65,6 +65,8 @@ type Server struct {
mutex sync.Mutex mutex sync.Mutex
config *profilemanager.Config config *profilemanager.Config
proto.UnimplementedDaemonServiceServer proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
connectClient *internal.ConnectClient connectClient *internal.ConnectClient
@@ -103,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
func (s *Server) Start() error { func (s *Server) Start() error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil { if err := handlePanicLog(); err != nil {
@@ -172,8 +175,12 @@ func (s *Server) Start() error {
return nil return nil
} }
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) if s.clientRunning {
return nil
}
s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
return nil return nil
} }
@@ -204,12 +211,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost. // mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
runningChan chan struct{}, defer func() {
) { s.mutex.Lock()
backOff := getConnectWithBackoff(ctx) s.clientRunning = false
retryStarted := false s.mutex.Unlock()
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
return
}
backOff := getConnectWithBackoff(ctx)
go func() { go func() {
t := time.NewTicker(24 * time.Hour) t := time.NewTicker(24 * time.Hour)
for { for {
@@ -218,91 +235,34 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage
t.Stop() t.Stop()
return return
case <-t.C: case <-t.C:
if retryStarted { mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
mgmtState := statusRecorder.GetManagementState() if mgmtState.Connected && signalState.Connected {
signalState := statusRecorder.GetSignalState() log.Tracef("resetting status")
if mgmtState.Connected && signalState.Connected { backOff.Reset()
log.Tracef("resetting status") } else {
retryStarted = false log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
} }
} }
} }
}() }()
runOperation := func() error { runOperation := func() error {
log.Tracef("running client connection") err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
err := s.connectClient.Run(runningChan)
if err != nil { if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err) log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
} }
if config.DisableAutoConnect { log.Tracef("client connection exited gracefully, do not need to retry")
return backoff.Permanent(err) return nil
}
if !retryStarted {
retryStarted = true
backOff.Reset()
}
log.Tracef("client connection exited")
return fmt.Errorf("client connection exited")
} }
err := backoff.Retry(runOperation, backOff) if err := backoff.Retry(runOperation, backOff); err != nil {
if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { log.Errorf("operation failed: %v", err)
log.Errorf("received an error when trying to connect: %v", err)
} else {
log.Tracef("retry canceled")
} }
} }
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails // loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) { func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
var status internal.StatusType var status internal.StatusType
@@ -716,11 +676,14 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel() defer cancel()
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal if !s.clientRunning {
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
}
for { for {
select { select {
case <-runningChan: case <-s.clientRunningChan:
s.isSessionActive.Store(true) s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil return &proto.UpResponse{}, nil
case <-callerCtx.Done(): case <-callerCtx.Done():
@@ -1127,6 +1090,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
}, nil }, nil
} }
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err
}
return nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}
func (s *Server) onSessionExpire() { func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp() isUIActive := internal.CheckUIApp()
@@ -1138,6 +1229,45 @@ func (s *Server) onSessionExpire() {
} }
} }
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{ pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{}, ManagementState: &proto.ManagementState{},
@@ -1252,121 +1382,3 @@ func sendTerminalNotification() error {
return wallCmd.Wait() return wallCmd.Wait()
} }
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}

View File

@@ -10,25 +10,25 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto" daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -294,15 +294,20 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"strings" "strings"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
@@ -95,14 +96,6 @@ func (i *Info) SetFlags(
i.LazyConnectionEnabled = lazyConnectionEnabled i.LazyConnectionEnabled = lazyConnectionEnabled
} }
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string { func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx) md, hasMeta := metadata.FromOutgoingContext(ctx)
@@ -180,6 +173,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
// GetInfoWithChecks retrieves and parses the system information with applied checks. // GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))
processCheckPaths := make([]string, 0) processCheckPaths := make([]string, 0)
for _, check := range checks { for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...) processCheckPaths = append(processCheckPaths, check.GetFiles()...)
@@ -189,16 +183,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("gathering process check information completed")
info := GetInfo(ctx) info := GetInfo(ctx)
info.Files = files info.Files = files
log.Debugf("all system information gathered successfully")
return info, nil return info, nil
} }
// UpdateStaticInfo asynchronously updates static system and platform information
func UpdateStaticInfo() {
go func() {
_ = updateStaticInfo()
}()
}

View File

@@ -15,6 +15,11 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
kernel := "android" kernel := "android"

View File

@@ -19,6 +19,10 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
func UpdateStaticInfoAsync() {
go updateStaticInfo()
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
utsname := unix.Utsname{} utsname := unix.Utsname{}
@@ -41,7 +45,7 @@ func GetInfo(ctx context.Context) *Info {
} }
start := time.Now() start := time.Now()
si := updateStaticInfo() si := getStaticInfo()
if time.Since(start) > 1*time.Second { if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start)) log.Warnf("updateStaticInfo took %s", time.Since(start))
} }

View File

@@ -18,6 +18,11 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
out := _getInfo() out := _getInfo()

View File

@@ -10,6 +10,11 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {

View File

@@ -23,6 +23,10 @@ var (
getSystemInfo = defaultSysInfoImplementation getSystemInfo = defaultSysInfoImplementation
) )
func UpdateStaticInfoAsync() {
go updateStaticInfo()
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
info := _getInfo() info := _getInfo()
@@ -48,7 +52,7 @@ func GetInfo(ctx context.Context) *Info {
} }
start := time.Now() start := time.Now()
si := updateStaticInfo() si := getStaticInfo()
if time.Since(start) > 1*time.Second { if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start)) log.Warnf("updateStaticInfo took %s", time.Since(start))
} }

View File

@@ -2,187 +2,51 @@ package system
import ( import (
"context" "context"
"fmt"
"os" "os"
"runtime" "runtime"
"strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
type Win32_OperatingSystem struct { func UpdateStaticInfoAsync() {
Caption string go updateStaticInfo()
}
type Win32_ComputerSystem struct {
Manufacturer string
}
type Win32_ComputerSystemProduct struct {
Name string
}
type Win32_BIOS struct {
SerialNumber string
} }
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
osName, osVersion := getOSNameAndVersion()
buildVersion := getBuildVersion()
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
start := time.Now() start := time.Now()
si := updateStaticInfo() si := getStaticInfo()
if time.Since(start) > 1*time.Second { if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start)) log.Warnf("updateStaticInfo took %s", time.Since(start))
} }
gio := &Info{ gio := &Info{
Kernel: "windows", Kernel: "windows",
OSVersion: osVersion, OSVersion: si.OSVersion,
Platform: "unknown", Platform: "unknown",
OS: osName, OS: si.OSName,
GoOS: runtime.GOOS, GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
KernelVersion: buildVersion, KernelVersion: si.BuildVersion,
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName, SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment, Environment: si.Environment,
} }
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
} else {
gio.NetworkAddresses = addrs
}
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname) gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.NetbirdVersion = version.NetbirdVersion() gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)
return gio return gio
} }
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
log.Error(err)
return "Windows", getBuildVersion()
}
if len(dst) == 0 {
return "Windows", getBuildVersion()
}
split := strings.Split(dst[0].Caption, " ")
if len(split) <= 3 {
return "Windows", getBuildVersion()
}
name := split[1]
version := split[2]
if split[2] == "Server" {
name = fmt.Sprintf("%s %s", split[1], split[2])
version = split[3]
}
return name, version
}
func getBuildVersion() string {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
log.Error(err)
return "0.0.0.0"
}
defer func() {
deferErr := k.Close()
if deferErr != nil {
log.Error(deferErr)
}
}()
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
if err != nil {
log.Error(err)
}
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
if err != nil {
log.Error(err)
}
build, _, err := k.GetStringValue("CurrentBuildNumber")
if err != nil {
log.Error(err)
}
// Update Build Revision
ubr, _, err := k.GetIntegerValue("UBR")
if err != nil {
log.Error(err)
}
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
return ver
}
func sysNumber() (string, error) {
var dst []Win32_BIOS
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].SerialNumber, nil
}
func sysProductName() (string, error) {
var dst []Win32_ComputerSystemProduct
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
// `ComputerSystemProduct` could be empty on some virtualized systems
if len(dst) < 1 {
return "unknown", nil
}
return dst[0].Name, nil
}
func sysManufacturer() (string, error) {
var dst []Win32_ComputerSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].Manufacturer, nil
}

View File

@@ -3,12 +3,7 @@
package system package system
import ( import (
"context"
"sync" "sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
) )
var ( var (
@@ -16,25 +11,26 @@ var (
once sync.Once once sync.Once
) )
func updateStaticInfo() StaticInfo { // StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
// Windows specific fields
OSName string
OSVersion string
BuildVersion string
}
func updateStaticInfo() {
once.Do(func() { once.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) staticInfo = newStaticInfo()
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
}) })
}
func getStaticInfo() StaticInfo {
updateStaticInfo()
return staticInfo return staticInfo
} }

View File

@@ -1,8 +0,0 @@
//go:build android || freebsd || ios
package system
// updateStaticInfo returns an empty implementation for unsupported platforms
func updateStaticInfo() StaticInfo {
return StaticInfo{}
}

View File

@@ -0,0 +1,35 @@
//go:build (linux && !android) || (darwin && !ios)
package system
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
func newStaticInfo() StaticInfo {
si := StaticInfo{}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
si.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
si.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
return si
}

View File

@@ -0,0 +1,184 @@
package system
import (
"context"
"fmt"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
type Win32_OperatingSystem struct {
Caption string
}
type Win32_ComputerSystem struct {
Manufacturer string
}
type Win32_ComputerSystemProduct struct {
Name string
}
type Win32_BIOS struct {
SerialNumber string
}
func newStaticInfo() StaticInfo {
si := StaticInfo{}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
wg.Done()
}()
wg.Add(1)
go func() {
si.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
wg.Add(1)
go func() {
si.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Add(1)
go func() {
si.OSName, si.OSVersion = getOSNameAndVersion()
wg.Done()
}()
wg.Add(1)
go func() {
si.BuildVersion = getBuildVersion()
wg.Done()
}()
wg.Wait()
return si
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func sysNumber() (string, error) {
var dst []Win32_BIOS
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].SerialNumber, nil
}
func sysProductName() (string, error) {
var dst []Win32_ComputerSystemProduct
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
// `ComputerSystemProduct` could be empty on some virtualized systems
if len(dst) < 1 {
return "unknown", nil
}
return dst[0].Name, nil
}
func sysManufacturer() (string, error) {
var dst []Win32_ComputerSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].Manufacturer, nil
}
func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
log.Error(err)
return "Windows", getBuildVersion()
}
if len(dst) == 0 {
return "Windows", getBuildVersion()
}
split := strings.Split(dst[0].Caption, " ")
if len(split) <= 3 {
return "Windows", getBuildVersion()
}
name := split[1]
version := split[2]
if split[2] == "Server" {
name = fmt.Sprintf("%s %s", split[1], split[2])
version = split[3]
}
return name, version
}
func getBuildVersion() string {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
log.Error(err)
return "0.0.0.0"
}
defer func() {
deferErr := k.Close()
if deferErr != nil {
log.Error(deferErr)
}
}()
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
if err != nil {
log.Error(err)
}
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
if err != nil {
log.Error(err)
}
build, _, err := k.GetStringValue("CurrentBuildNumber")
if err != nil {
log.Error(err)
}
// Update Build Revision
ubr, _, err := k.GetIntegerValue("UBR")
if err != nil {
log.Error(err)
}
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
return ver
}

20
go.mod
View File

@@ -12,7 +12,6 @@ require (
github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83 github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83
github.com/onsi/ginkgo v1.16.5 github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.27.6 github.com/onsi/gomega v1.27.6
github.com/pion/ice/v3 v3.0.2
github.com/rs/cors v1.8.0 github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
@@ -24,7 +23,7 @@ require (
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.73.0 google.golang.org/grpc v1.73.0
google.golang.org/protobuf v1.36.6 google.golang.org/protobuf v1.36.8
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@@ -63,16 +62,18 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203
github.com/pion/logging v0.2.2 github.com/pion/ice/v4 v4.0.0-00010101000000-000000000000
github.com/pion/logging v0.2.4
github.com/pion/randutil v0.1.0 github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0 github.com/pion/stun/v2 v2.0.0
github.com/pion/transport/v3 v3.0.1 github.com/pion/stun/v3 v3.0.0
github.com/pion/transport/v3 v3.0.7
github.com/pion/turn/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.22.0 github.com/prometheus/client_golang v1.22.0
github.com/quic-go/quic-go v0.48.2 github.com/quic-go/quic-go v0.48.2
@@ -151,7 +152,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.0.0+incompatible // indirect github.com/docker/docker v26.1.5+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
@@ -212,8 +213,10 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect github.com/pion/dtls/v3 v3.0.7 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
@@ -229,6 +232,7 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect github.com/vishvananda/netns v0.0.4 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
github.com/yuin/goldmark v1.7.1 // indirect github.com/yuin/goldmark v1.7.1 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
@@ -257,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944

35
go.sum
View File

@@ -167,8 +167,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.0.0+incompatible h1:Olh0KS820sJ7nPsBKChVhk5pzqcwDR15fumfAd/p9hM= github.com/docker/docker v26.1.5+incompatible h1:NEAxTwEjxV6VbBMBoGG3zPqbiJosIApZjxlbrG9q3/g=
github.com/docker/docker v28.0.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/docker v26.1.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
@@ -501,10 +501,10 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 h1:/ZbExdcDwRq6XgTpTf5I1DPqnC3eInEf0fcmkqR8eSg= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
@@ -546,21 +546,29 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v3 v3.0.1 h1:gDTlPJwROfSfz6QfSi0ZmeCSkFcnWWiiR9ES0ouANiM=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -681,6 +689,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -1167,11 +1177,10 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=

View File

@@ -328,6 +328,45 @@ delete_auto_service_user() {
echo "$PARSED_RESPONSE" echo "$PARSED_RESPONSE"
} }
delete_default_zitadel_admin() {
INSTANCE_URL=$1
PAT=$2
# Search for the default zitadel-admin user
RESPONSE=$(
curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
-d '{
"queries": [
{
"userNameQuery": {
"userName": "zitadel-admin@",
"method": "TEXT_QUERY_METHOD_STARTS_WITH"
}
}
]
}'
)
DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty')
if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then
echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID"
RESPONSE=$(
curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
)
PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"')
handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE"
else
echo "Default zitadel-admin user not found: $RESPONSE"
fi
}
init_zitadel() { init_zitadel() {
echo -e "\nInitializing Zitadel with NetBird's applications\n" echo -e "\nInitializing Zitadel with NetBird's applications\n"
INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
@@ -346,6 +385,9 @@ init_zitadel() {
echo -n "Waiting for Zitadel to become ready " echo -n "Waiting for Zitadel to become ready "
wait_api "$INSTANCE_URL" "$PAT" wait_api "$INSTANCE_URL" "$PAT"
echo "Deleting default zitadel-admin user..."
delete_default_zitadel_admin "$INSTANCE_URL" "$PAT"
# create the zitadel project # create the zitadel project
echo "Creating new zitadel project" echo "Creating new zitadel project"
PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT")

View File

@@ -111,3 +111,6 @@ Generate gRpc code:
#!/bin/bash #!/bin/bash
protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=. protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=.
``` ```

View File

@@ -20,7 +20,11 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore()) integratedPeerValidator, err := integrations.NewIntegratedValidator(
context.Background(),
s.PeersManager(),
s.SettingsManager(),
s.EventStore())
if err != nil { if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err) log.Errorf("failed to create integrated peer validator: %v", err)
} }

View File

@@ -104,6 +104,8 @@ type DefaultAccountManager struct {
accountUpdateLocks sync.Map accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64 updateAccountPeersBufferInterval atomic.Int64
loginFilter *loginFilter
disableDefaultPolicy bool disableDefaultPolicy bool
} }
@@ -211,6 +213,7 @@ func BuildManager(
proxyController: proxyController, proxyController: proxyController,
settingsManager: settingsManager, settingsManager: settingsManager,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy, disableDefaultPolicy: disableDefaultPolicy,
} }
@@ -1133,7 +1136,18 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
newUser := types.NewRegularUser(userAuth.UserId) newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, newUser)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID)
if err != nil {
return "", err
}
if settings != nil && settings.Extra != nil && settings.Extra.UserApprovalRequired {
newUser.Blocked = true
newUser.PendingApproval = true
}
err = am.Store.SaveUser(ctx, newUser)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -1143,7 +1157,11 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
return "", err return "", err
} }
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) if newUser.PendingApproval {
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
} else {
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
}
return domainAccountID, nil return domainAccountID, nil
} }
@@ -1612,6 +1630,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
} }
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
return am.loginFilter.allowLogin(wgPubKey, metahash)
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
@@ -1628,6 +1650,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
} }
metahash := metaHash(meta, realIP.String())
am.loginFilter.addLogin(peerPubKey, metahash)
return peer, netMap, postureChecks, nil return peer, netMap, postureChecks, nil
} }
@@ -1636,7 +1661,6 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
} }
return nil return nil
} }
@@ -1690,7 +1714,9 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account
log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err) log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err)
continue continue
} }
peers = append(peers, peer) if peer.UserID != "" {
peers = append(peers, peer)
}
} }
if len(peers) > 0 { if len(peers) > 0 {
err := am.expireAndUpdatePeers(ctx, accountID, peers) err := am.expireAndUpdatePeers(ctx, accountID, peers)
@@ -1786,6 +1812,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
PeerInactivityExpirationEnabled: false, PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true, RoutingPeerDNSResolutionEnabled: true,
Extra: &types.ExtraSettings{
UserApprovalRequired: true,
},
}, },
Onboarding: types.AccountOnboarding{ Onboarding: types.AccountOnboarding{
OnboardingFlowPending: true, OnboardingFlowPending: true,
@@ -1892,6 +1921,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
PeerInactivityExpirationEnabled: false, PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true, RoutingPeerDNSResolutionEnabled: true,
Extra: &types.ExtraSettings{
UserApprovalRequired: true,
},
}, },
} }

View File

@@ -32,6 +32,8 @@ type Manager interface {
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error)
@@ -123,4 +125,5 @@ type Manager interface {
UpdateToPrimaryAccount(ctx context.Context, accountId string) error UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
AllowSync(string, uint64) bool
} }

View File

@@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -25,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -3046,19 +3048,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op") b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" { if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "sync", "syncAndMark")
} }
if msPerOp < minExpected { if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
} }
}) })
} }
@@ -3121,19 +3118,14 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op") b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" { if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "existingPeer")
} }
if msPerOp < minExpected { if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
} }
}) })
} }
@@ -3196,24 +3188,44 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op") b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" { if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
} }
if msPerOp < minExpected { if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
} }
}) })
} }
} }
func TestMain(m *testing.M) {
exitCode := m.Run()
if exitCode == 0 && os.Getenv("CI") == "true" {
runID := os.Getenv("GITHUB_RUN_ID")
storeEngine := os.Getenv("NETBIRD_STORE_ENGINE")
err := push.New("http://localhost:9091", "account_manager_benchmark").
Collector(testing_tools.BenchmarkDuration).
Grouping("ci_run", runID).
Grouping("store_engine", storeEngine).
Push()
if err != nil {
log.Printf("Failed to push metrics: %v", err)
} else {
time.Sleep(1 * time.Minute)
_ = push.New("http://localhost:9091", "account_manager_benchmark").
Grouping("ci_run", runID).
Grouping("store_engine", storeEngine).
Delete()
}
}
os.Exit(exitCode)
}
func Test_GetCreateAccountByPrivateDomain(t *testing.T) { func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
@@ -3594,3 +3606,93 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
require.Error(t, err, "should fail with invalid peer ID") require.Error(t, err, "should fail with invalid peer ID")
}) })
} }
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create a domain-based account with user approval enabled
existingAccountID := "existing-account"
account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false)
account.Settings.Extra = &types.ExtraSettings{
UserApprovalRequired: true,
}
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Set the account as domain primary account
account.IsDomainPrimaryAccount = true
account.DomainCategory = types.PrivateCategory
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Test adding new user to existing account with approval required
newUserID := "new-user-id"
userAuth := nbcontext.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,
}
acc, err := manager.Store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
require.True(t, acc.IsDomainPrimaryAccount, "Account should be primary for the domain")
require.Equal(t, "example.com", acc.Domain, "Account domain should match")
returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
require.NoError(t, err)
require.Equal(t, existingAccountID, returnedAccountID)
// Verify user was created with pending approval
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
require.NoError(t, err)
assert.True(t, user.Blocked, "User should be blocked when approval is required")
assert.True(t, user.PendingApproval, "User should be pending approval")
assert.Equal(t, existingAccountID, user.AccountID)
}
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create a domain-based account without user approval
ownerUserAuth := nbcontext.UserAuth{
UserId: "owner-user",
Domain: "example.com",
DomainCategory: types.PrivateCategory,
}
existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth)
require.NoError(t, err)
// Modify the account to disable user approval
account, err := manager.Store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
account.Settings.Extra = &types.ExtraSettings{
UserApprovalRequired: false,
}
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Test adding new user to existing account without approval required
newUserID := "new-user-id"
userAuth := nbcontext.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,
}
returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
require.NoError(t, err)
require.Equal(t, existingAccountID, returnedAccountID)
// Verify user was created without pending approval
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
require.NoError(t, err)
assert.False(t, user.Blocked, "User should not be blocked when approval is not required")
assert.False(t, user.PendingApproval, "User should not be pending approval")
assert.Equal(t, existingAccountID, user.AccountID)
}

View File

@@ -177,6 +177,8 @@ const (
AccountNetworkRangeUpdated Activity = 87 AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88 PeerIPUpdated Activity = 88
UserApproved Activity = 89
UserRejected Activity = 90
AccountDeleted Activity = 99999 AccountDeleted Activity = 99999
) )
@@ -284,6 +286,8 @@ var activityMap = map[Activity]Code{
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"}, AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
UserApproved: {"User approved", "user.approve"},
UserRejected: {"User rejected", "user.reject"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -202,35 +202,45 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
} }
var eventsToStore []func() var eventsToStore []func()
var groupsToSave []*types.Group
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var globalErr error
groupIDs := make([]string, 0, len(groups)) groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups { for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
newGroup.AccountID = accountID newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
return err
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return err
}
groupIDs = append(groupIDs, newGroup.ID) groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) return nil
})
if err != nil { if err != nil {
return err log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
if len(groupIDs) == 1 {
return err
}
globalErr = errors.Join(globalErr, err)
// continue updating other groups
} }
}
if err = transaction.CreateGroups(ctx, accountID, groupsToSave); err != nil { updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil { if err != nil {
return err return err
} }
@@ -243,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
return nil return globalErr
} }
// UpdateGroups updates groups in the account. // UpdateGroups updates groups in the account.
@@ -260,35 +270,45 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
} }
var eventsToStore []func() var eventsToStore []func()
var groupsToSave []*types.Group
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var globalErr error
groupIDs := make([]string, 0, len(groups)) groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups { for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
newGroup.AccountID = accountID newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID) if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return err
}
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) groupIDs = append(groupIDs, newGroup.ID)
return nil
})
if err != nil { if err != nil {
return err log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
if len(groups) == 1 {
return err
}
globalErr = errors.Join(globalErr, err)
// continue updating other groups
} }
}
if err = transaction.UpdateGroups(ctx, accountID, groupsToSave); err != nil { updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil { if err != nil {
return err return err
} }
@@ -301,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
return nil return globalErr
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
@@ -584,13 +604,6 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
newGroup.ID = xid.New().String() newGroup.ID = xid.New().String()
} }
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil return nil
} }

View File

@@ -2,9 +2,11 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -38,20 +40,28 @@ import (
internalStatus "github.com/netbirdio/netbird/shared/management/status" internalStatus "github.com/netbirdio/netbird/shared/management/status"
) )
const (
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
envBlockPeers = "NB_BLOCK_SAME_PEERS"
)
// GRPCServer an instance of a Management gRPC API server // GRPCServer an instance of a Management gRPC API server
type GRPCServer struct { type GRPCServer struct {
accountManager account.Manager accountManager account.Manager
settingsManager settings.Manager settingsManager settings.Manager
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *nbconfig.Config config *nbconfig.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
peerLocks sync.Map peerLocks sync.Map
authManager auth.Manager authManager auth.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
logBlockedPeers bool
blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator
} }
// NewServer creates a new Management server // NewServer creates a new Management server
@@ -82,18 +92,23 @@ func NewServer(
} }
} }
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
accountManager: accountManager, accountManager: accountManager,
settingsManager: settingsManager, settingsManager: settingsManager,
config: config, config: config,
secretsManager: secretsManager, secretsManager: secretsManager,
authManager: authManager, authManager: authManager,
appMetrics: appMetrics, appMetrics: appMetrics,
ephemeralManager: ephemeralManager, ephemeralManager: ephemeralManager,
integratedPeerValidator: integratedPeerValidator, logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
}, nil }, nil
} }
@@ -136,9 +151,6 @@ func getRealIP(ctx context.Context) net.IP {
// notifies the connected peer of any updates (e.g. new peers under the same account) // notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
reqStart := time.Now() reqStart := time.Now()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
ctx := srv.Context() ctx := srv.Context()
@@ -147,6 +159,25 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if err != nil { if err != nil {
return err return err
} }
realIP := getRealIP(ctx)
sRealIP := realIP.String()
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
}
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
@@ -172,14 +203,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil { if syncReq.GetMeta() == nil {
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
} }
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err) return mapError(ctx, err)
@@ -345,6 +375,9 @@ func mapError(ctx context.Context, err error) error {
default: default:
} }
} }
if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) {
return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error())
}
log.WithContext(ctx).Errorf("got an unhandled error: %s", err) log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request") return status.Errorf(codes.Internal, "failed handling request")
} }
@@ -436,12 +469,9 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
// In case of the successful registration login is also successful // In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now() reqStart := time.Now()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) sRealIP := realIP.String()
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{} loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq) peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -449,6 +479,24 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, err return nil, err
} }
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
}
if s.blockPeersWithSameConfig {
return nil, internalStatus.ErrPeerAlreadyLoggedIn
}
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
//nolint //nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
@@ -485,7 +533,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{ peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
WireGuardPubKey: peerKey.String(), WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey), SSHKey: string(sshKey),
Meta: extractPeerMeta(ctx, loginReq.GetMeta()), Meta: peerMeta,
UserID: userID, UserID: userID,
SetupKey: loginReq.GetSetupKey(), SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP, ConnectionIP: realIP,
@@ -951,8 +999,6 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, mapError(ctx, err) return nil, mapError(ctx, err)
} }
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start)) log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
return &proto.Empty{}, nil return &proto.Empty{}, nil

View File

@@ -11,11 +11,11 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
) )
const ( const (
@@ -198,6 +198,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
if req.Settings.Extra != nil { if req.Settings.Extra != nil {
settings.Extra = &types.ExtraSettings{ settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
@@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
if settings.Extra != nil { if settings.Extra != nil {
apiSettings.Extra = &api.AccountExtraSettings{ apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: settings.Extra.UserApprovalRequired,
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled, NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
NetworkTrafficLogsGroups: settings.Extra.FlowGroups, NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled, NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,

View File

@@ -15,11 +15,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
) )
func initAccountsTestData(t *testing.T, account *types.Account) *handler { func initAccountsTestData(t *testing.T, account *types.Account) *handler {

View File

@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
) )
// Handler is a handler that returns peers of the account // Handler is a handler that returns peers of the account
@@ -354,7 +354,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
} }
return &api.Peer{ return &api.Peer{
CreatedAt: peer.CreatedAt, CreatedAt: peer.CreatedAt,
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
Ip: peer.IP.String(), Ip: peer.IP.String(),
@@ -391,33 +391,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
} }
return &api.PeerBatch{ return &api.PeerBatch{
CreatedAt: peer.CreatedAt, CreatedAt: peer.CreatedAt,
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
Ip: peer.IP.String(), Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(), ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected, Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen, LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion, KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID), GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion, Version: peer.Meta.WtVersion,
Groups: groupsInfo, Groups: groupsInfo,
SshEnabled: peer.SSHEnabled, SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname, Hostname: peer.Meta.Hostname,
UserId: peer.UserID, UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion, UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain), DnsLabel: fqdn(peer, dnsDomain),
ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled, LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.GetLastLogin(), LastLogin: peer.GetLastLogin(),
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
AccessiblePeersCount: accessiblePeersCount, AccessiblePeersCount: accessiblePeersCount,
CountryCode: peer.Location.CountryCode, CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled, InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
} }
} }

View File

@@ -9,11 +9,11 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
) )
@@ -31,6 +31,8 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
addUsersTokensEndpoint(accountManager, router) addUsersTokensEndpoint(accountManager, router)
} }
@@ -323,17 +325,76 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
} }
isCurrent := user.ID == currenUserID isCurrent := user.ID == currenUserID
return &api.User{ return &api.User{
Id: user.ID, Id: user.ID,
Name: user.Name, Name: user.Name,
Email: user.Email, Email: user.Email,
Role: user.Role, Role: user.Role,
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: userStatus, Status: userStatus,
IsCurrent: &isCurrent, IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser, IsServiceUser: &user.IsServiceUser,
IsBlocked: user.IsBlocked, IsBlocked: user.IsBlocked,
LastLogin: &user.LastLogin, LastLogin: &user.LastLogin,
Issued: &user.Issued, Issued: &user.Issued,
PendingApproval: user.PendingApproval,
} }
} }
// approveUser is a POST request to approve a user that is pending approval
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
userResponse := toUserResponse(user, userAuth.UserId)
util.WriteJSONObject(r.Context(), w, userResponse)
}
// rejectUser is a DELETE request to reject a user that is pending approval
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -16,13 +16,13 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -725,3 +725,133 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri
} }
return modules return modules
} }
func TestApproveUserEndpoint(t *testing.T) {
adminUser := &types.User{
Id: "admin-user",
Role: types.UserRoleAdmin,
AccountID: existingAccountID,
AutoGroups: []string{},
}
pendingUser := &types.User{
Id: "pending-user",
Role: types.UserRoleUser,
AccountID: existingAccountID,
Blocked: true,
PendingApproval: true,
AutoGroups: []string{},
}
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestingUser *types.User
}{
{
name: "approve user as admin should return 200",
expectedStatus: 200,
expectedBody: true,
requestingUser: adminUser,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{}
am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
approvedUserInfo := &types.UserInfo{
ID: pendingUser.Id,
Email: "pending@example.com",
Name: "Pending User",
Role: string(pendingUser.Role),
AutoGroups: []string{},
IsServiceUser: false,
IsBlocked: false,
PendingApproval: false,
LastLogin: time.Now(),
Issued: types.UserIssuedAPI,
}
return approvedUserInfo, nil
}
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedBody {
var response api.User
err = json.Unmarshal(rr.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "pending-user", response.Id)
assert.False(t, response.IsBlocked)
assert.False(t, response.PendingApproval)
}
})
}
}
func TestRejectUserEndpoint(t *testing.T) {
adminUser := &types.User{
Id: "admin-user",
Role: types.UserRoleAdmin,
AccountID: existingAccountID,
AutoGroups: []string{},
}
tt := []struct {
name string
expectedStatus int
requestingUser *types.User
}{
{
name: "reject user as admin should return 200",
expectedStatus: 200,
requestingUser: adminUser,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{}
am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
return nil
}
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
})
}
}

View File

@@ -17,8 +17,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
) )
const modulePeers = "peers" const modulePeers = "peers"
@@ -47,7 +48,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
for name, bc := range benchCasesPeers { for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -65,7 +66,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
}) })
} }
} }
@@ -82,7 +83,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
for name, bc := range benchCasesPeers { for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -92,7 +93,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
}) })
} }
} }
@@ -109,7 +110,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
for name, bc := range benchCasesPeers { for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -119,7 +120,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
}) })
} }
} }
@@ -136,7 +137,7 @@ func BenchmarkDeletePeer(b *testing.B) {
for name, bc := range benchCasesPeers { for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -146,7 +147,7 @@ func BenchmarkDeletePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
}) })
} }
} }

View File

@@ -17,8 +17,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
) )
// Map to store peers, groups, users, and setupKeys by name // Map to store peers, groups, users, and setupKeys by name
@@ -47,7 +48,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys { for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -69,7 +70,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
}) })
} }
} }
@@ -86,7 +87,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys { for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -109,7 +110,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
}) })
} }
} }
@@ -126,7 +127,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys { for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -136,7 +137,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
}) })
} }
} }
@@ -153,7 +154,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
for name, bc := range benchCasesSetupKeys { for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer() b.ResetTimer()
@@ -163,7 +164,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
}) })
} }
} }
@@ -180,7 +181,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys { for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
b.ResetTimer() b.ResetTimer()
@@ -190,7 +191,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
}) })
} }
} }

View File

@@ -18,8 +18,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
) )
const moduleUsers = "users" const moduleUsers = "users"
@@ -46,7 +47,7 @@ func BenchmarkUpdateUser(b *testing.B) {
for name, bc := range benchCasesUsers { for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -71,7 +72,7 @@ func BenchmarkUpdateUser(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
}) })
} }
} }
@@ -84,18 +85,18 @@ func BenchmarkGetOneUser(b *testing.B) {
for name, bc := range benchCasesUsers { for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
b.ResetTimer() b.ResetTimer()
start := time.Now() start := time.Now()
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
}) })
} }
} }
@@ -110,18 +111,18 @@ func BenchmarkGetAllUsers(b *testing.B) {
for name, bc := range benchCasesUsers { for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
b.ResetTimer() b.ResetTimer()
start := time.Now() start := time.Now()
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
}) })
} }
} }
@@ -136,7 +137,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
for name, bc := range benchCasesUsers { for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -147,7 +148,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
apiHandler.ServeHTTP(recorder, req) apiHandler.ServeHTTP(recorder, req)
} }
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
}) })
} }
} }

View File

@@ -15,9 +15,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
) )
func Test_SetupKeys_Create(t *testing.T) { func Test_SetupKeys_Create(t *testing.T) {
@@ -287,7 +288,7 @@ func Test_SetupKeys_Create(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
for _, user := range users { for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) { t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
body, err := json.Marshal(tc.requestBody) body, err := json.Marshal(tc.requestBody)
if err != nil { if err != nil {
@@ -572,7 +573,7 @@ func Test_SetupKeys_Update(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
for _, user := range users { for _, user := range users {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
body, err := json.Marshal(tc.requestBody) body, err := json.Marshal(tc.requestBody)
if err != nil { if err != nil {
@@ -751,7 +752,7 @@ func Test_SetupKeys_Get(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
for _, user := range users { for _, user := range users {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
@@ -903,7 +904,7 @@ func Test_SetupKeys_GetAll(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
for _, user := range users { for _, user := range users {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId)
@@ -1087,7 +1088,7 @@ func Test_SetupKeys_Delete(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
for _, user := range users { for _, user := range users {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)

View File

@@ -0,0 +1,137 @@
package channel
import (
"context"
"errors"
"net/http"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
)
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
}
t.Cleanup(cleanup)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
done := make(chan struct{})
if validateUpdate {
go func() {
if expectedPeerUpdate != nil {
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
} else {
peerShouldNotReceiveUpdate(t, updMsg)
}
close(done)
}()
}
geoMock := &geolocation.Mock{}
validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
GetPATInfoFunc: authManager.GetPATInfo,
}
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
return apiHandler, am, done
}
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
t.Errorf("Unexpected message received: %+v", msg)
case <-time.After(500 * time.Millisecond):
return
}
}
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
assert.Equal(t, expected, msg)
case <-time.After(500 * time.Millisecond):
t.Errorf("Timed out waiting for update message")
}
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
userAuth := nbcontext.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
userAuth.UserId = token
userAuth.AccountId = "testAccountId"
userAuth.Domain = "test.com"
userAuth.DomainCategory = "private"
case "otherUserId":
userAuth.UserId = "otherUserId"
userAuth.AccountId = "otherAccountId"
userAuth.Domain = "other.com"
userAuth.DomainCategory = "private"
case "invalidToken":
return userAuth, nil, errors.New("invalid token")
}
jwtToken := jwt.New(jwt.SigningMethodHS256)
return userAuth, jwtToken, nil
}

View File

@@ -3,7 +3,6 @@ package testing_tools
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -14,32 +13,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt/v5"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
) )
@@ -106,90 +85,6 @@ type PerformanceMetrics struct {
MaxMsPerOpCICD float64 MaxMsPerOpCICD float64
} }
func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
}
t.Cleanup(cleanup)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId)
done := make(chan struct{})
if validateUpdate {
go func() {
if expectedPeerUpdate != nil {
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
} else {
peerShouldNotReceiveUpdate(t, updMsg)
}
close(done)
}()
}
geoMock := &geolocation.Mock{}
validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
GetPATInfoFunc: authManager.GetPATInfo,
}
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
return apiHandler, am, done
}
func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
t.Errorf("Unexpected message received: %+v", msg)
case <-time.After(500 * time.Millisecond):
return
}
}
func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
assert.Equal(t, expected, msg)
case <-time.After(500 * time.Millisecond):
t.Errorf("Timed out waiting for update message")
}
}
func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request { func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request {
t.Helper() t.Helper()
@@ -222,11 +117,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta
return content, expectedStatus == http.StatusOK return content, expectedStatus == http.StatusOK
} }
func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) {
b.Helper() b.Helper()
ctx := context.Background() ctx := context.Background()
account, err := am.GetAccount(ctx, TestAccountId) acc, err := am.GetAccount(ctx, TestAccountId)
if err != nil { if err != nil {
b.Fatalf("Failed to get account: %v", err) b.Fatalf("Failed to get account: %v", err)
} }
@@ -242,23 +137,23 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true},
UserID: TestUserId, UserID: TestUserId,
} }
account.Peers[peer.ID] = peer acc.Peers[peer.ID] = peer
} }
// Create users // Create users
for i := 0; i < users; i++ { for i := 0; i < users; i++ {
user := &types.User{ user := &types.User{
Id: fmt.Sprintf("olduser-%d", i), Id: fmt.Sprintf("olduser-%d", i),
AccountID: account.Id, AccountID: acc.Id,
Role: types.UserRoleUser, Role: types.UserRoleUser,
} }
account.Users[user.Id] = user acc.Users[user.Id] = user
} }
for i := 0; i < setupKeys; i++ { for i := 0; i < setupKeys; i++ {
key := &types.SetupKey{ key := &types.SetupKey{
Id: fmt.Sprintf("oldkey-%d", i), Id: fmt.Sprintf("oldkey-%d", i),
AccountID: account.Id, AccountID: acc.Id,
AutoGroups: []string{"someGroupID"}, AutoGroups: []string{"someGroupID"},
UpdatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)), ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)),
@@ -266,11 +161,11 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
Type: "reusable", Type: "reusable",
UsageLimit: 0, UsageLimit: 0,
} }
account.SetupKeys[key.Id] = key acc.SetupKeys[key.Id] = key
} }
// Create groups and policies // Create groups and policies
account.Policies = make([]*types.Policy, 0, groups) acc.Policies = make([]*types.Policy, 0, groups)
for i := 0; i < groups; i++ { for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i) groupID := fmt.Sprintf("group-%d", i)
group := &types.Group{ group := &types.Group{
@@ -281,7 +176,7 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
peerIndex := i*(peers/groups) + j peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
} }
account.Groups[groupID] = group acc.Groups[groupID] = group
// Create a policy for this group // Create a policy for this group
policy := &types.Policy{ policy := &types.Policy{
@@ -301,10 +196,10 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
}, },
}, },
} }
account.Policies = append(account.Policies, policy) acc.Policies = append(acc.Policies, policy)
} }
account.PostureChecks = []*posture.Checks{ acc.PostureChecks = []*posture.Checks{
{ {
ID: "PostureChecksAll", ID: "PostureChecksAll",
Name: "All", Name: "All",
@@ -316,52 +211,38 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
}, },
} }
err = am.Store.SaveAccount(context.Background(), account) store := am.GetStore()
err = store.SaveAccount(context.Background(), acc)
if err != nil { if err != nil {
b.Fatalf("Failed to save account: %v", err) b.Fatalf("Failed to save account: %v", err)
} }
} }
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
b.Helper() b.Helper()
branch := os.Getenv("GIT_BRANCH")
if branch == "" {
b.Fatalf("environment variable GIT_BRANCH is not set")
}
if recorder.Code != http.StatusOK { if recorder.Code != http.StatusOK {
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code) b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
} }
EvaluateBenchmarkResults(b, testCase, duration, module, operation)
}
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) {
b.Helper()
branch := os.Getenv("GIT_BRANCH")
if branch == "" && os.Getenv("CI") == "true" {
b.Fatalf("environment variable GIT_BRANCH is not set")
}
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch) gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch)
gauge.Set(msPerOp) gauge.Set(msPerOp)
b.ReportMetric(msPerOp, "ms/op") b.ReportMetric(msPerOp, "ms/op")
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
userAuth := nbcontext.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
userAuth.UserId = token
userAuth.AccountId = "testAccountId"
userAuth.Domain = "test.com"
userAuth.DomainCategory = "private"
case "otherUserId":
userAuth.UserId = "otherUserId"
userAuth.AccountId = "otherAccountId"
userAuth.Domain = "other.com"
userAuth.DomainCategory = "private"
case "invalidToken":
return userAuth, nil, errors.New("invalid token")
}
jwtToken := jwt.New(jwt.SigningMethodHS256)
return userAuth, jwtToken, nil
} }

View File

@@ -0,0 +1,160 @@
package server
import (
"hash/fnv"
"math"
"sync"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
)
type lfConfig struct {
reconnThreshold time.Duration
baseBlockDuration time.Duration
reconnLimitForBan int
metaChangeLimit int
}
func initCfg() *lfConfig {
return &lfConfig{
reconnThreshold: reconnThreshold,
baseBlockDuration: baseBlockDuration,
reconnLimitForBan: reconnLimitForBan,
metaChangeLimit: metaChangeLimit,
}
}
type loginFilter struct {
mu sync.RWMutex
cfg *lfConfig
logged map[string]*peerState
}
type peerState struct {
currentHash uint64
sessionCounter int
sessionStart time.Time
lastSeen time.Time
isBanned bool
banLevel int
banExpiresAt time.Time
metaChangeCounter int
metaChangeWindowStart time.Time
}
func newLoginFilter() *loginFilter {
return newLoginFilterWithCfg(initCfg())
}
func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter {
return &loginFilter{
logged: make(map[string]*peerState),
cfg: cfg,
}
}
func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool {
l.mu.RLock()
defer func() {
l.mu.RUnlock()
}()
state, ok := l.logged[wgPubKey]
if !ok {
return true
}
if state.isBanned && time.Now().Before(state.banExpiresAt) {
return false
}
if metaHash != state.currentHash {
if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit {
return false
}
}
return true
}
func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
now := time.Now()
l.mu.Lock()
defer func() {
l.mu.Unlock()
}()
state, ok := l.logged[wgPubKey]
if !ok {
l.logged[wgPubKey] = &peerState{
currentHash: metaHash,
sessionCounter: 1,
sessionStart: now,
lastSeen: now,
metaChangeWindowStart: now,
metaChangeCounter: 1,
}
return
}
if state.isBanned && now.After(state.banExpiresAt) {
state.isBanned = false
}
if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) {
state.banLevel = 0
}
if metaHash != state.currentHash {
if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) {
state.metaChangeWindowStart = now
state.metaChangeCounter = 1
} else {
state.metaChangeCounter++
}
state.currentHash = metaHash
state.sessionCounter = 1
state.sessionStart = now
state.lastSeen = now
return
}
state.sessionCounter++
if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold {
state.isBanned = true
state.banLevel++
backoffFactor := math.Pow(2, float64(state.banLevel-1))
duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor)
state.banExpiresAt = now.Add(duration)
state.sessionCounter = 0
state.sessionStart = now
}
state.lastSeen = now
}
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
macs := uint64(0)
for _, na := range meta.NetworkAddresses {
for _, r := range na.Mac {
macs += uint64(r)
}
}
return h.Sum64() + macs
}

View File

@@ -0,0 +1,275 @@
package server
import (
"hash/fnv"
"math"
"math/rand"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func testAdvancedCfg() *lfConfig {
return &lfConfig{
reconnThreshold: 50 * time.Millisecond,
baseBlockDuration: 100 * time.Millisecond,
reconnLimitForBan: 3,
metaChangeLimit: 2,
}
}
type LoginFilterTestSuite struct {
suite.Suite
filter *loginFilter
}
func (s *LoginFilterTestSuite) SetupTest() {
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
}
func TestLoginFilterTestSuite(t *testing.T) {
suite.Run(t, new(LoginFilterTestSuite))
}
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
}
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.False(s.filter.allowLogin(pubKey, meta))
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
baseBan := s.filter.cfg.baseBlockDuration
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(1, s.filter.logged[pubKey].banLevel)
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
s.filter.logged[pubKey].isBanned = false
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
isBanned: true,
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
}
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.False(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
currentHash: meta,
banLevel: 3,
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
}
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(0, s.filter.logged[pubKey].banLevel)
}
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
pubKey := "PUB_KEY_A"
limit := s.filter.cfg.metaChangeLimit
for i := range limit {
s.filter.addLogin(pubKey, uint64(i+1))
}
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
s.False(isAllowed, "should block new meta hash after limit is reached")
}
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
pubKey := "PUB_KEY_A"
meta1 := uint64(1)
meta2 := uint64(2)
meta3 := uint64(3)
s.filter.addLogin(pubKey, meta1)
s.filter.addLogin(pubKey, meta2)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
s.filter.addLogin(pubKey, meta3)
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
}
func BenchmarkHashingMethods(b *testing.B) {
meta := nbpeer.PeerSystemMeta{
WtVersion: "1.25.1",
OSVersion: "Ubuntu 22.04.3 LTS",
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
b.Run("BuilderString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
}
})
b.Run("FnvHashToString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
}
})
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
}
})
_ = resultString
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
b.WriteString(meta.WtVersion)
b.WriteByte('|')
b.WriteString(meta.OSVersion)
b.WriteByte('|')
b.WriteString(meta.KernelVersion)
b.WriteByte('|')
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000
pubKeys := make([]string, numKeys)
for i := range numKeys {
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for pb.Next() {
key := pubKeys[r.Intn(numKeys)]
meta := r.Uint64()
if filter.allowLogin(key, meta) {
filter.addLogin(key, meta)
}
}
})
}

View File

@@ -95,6 +95,8 @@ type MockAccountManager struct {
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error) GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() account.ExternalCacheManager GetExternalCacheManagerFunc func() account.ExternalCacheManager
@@ -121,8 +123,10 @@ type MockAccountManager struct {
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
} }
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -605,6 +609,20 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string,
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented") return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
} }
func (am *MockAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
if am.ApproveUserFunc != nil {
return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented")
}
func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
if am.RejectUserFunc != nil {
return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented")
}
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil { if am.GetNameServerGroupFunc != nil {
@@ -953,3 +971,10 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
} }
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
} }
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
if am.AllowSyncFunc != nil {
return am.AllowSyncFunc(key, hash)
}
return true
}

View File

@@ -368,10 +368,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return fmt.Errorf("failed to remove peer from groups: %w", err)
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil { if err != nil {
return fmt.Errorf("failed to delete peer: %w", err) return fmt.Errorf("failed to delete peer: %w", err)
@@ -493,6 +489,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if err != nil { if err != nil {
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found") return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
} }
if user.PendingApproval {
return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
}
groupsToAdd = user.AutoGroups groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser opEvent.Activity = activity.PeerAddedByUser

View File

@@ -26,6 +26,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
@@ -989,19 +990,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op") b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" { if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
} }
if msPerOp < minExpected { if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
} }
}) })
} }
@@ -1609,7 +1605,6 @@ func Test_LoginPeer(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
setupKey string setupKey string
wireGuardPubKey string
expectExtraDNSLabelsMismatch bool expectExtraDNSLabelsMismatch bool
extraDNSLabels []string extraDNSLabels []string
expectLoginError bool expectLoginError bool
@@ -2388,3 +2383,186 @@ func TestBufferUpdateAccountPeers(t *testing.T) {
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
} }
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Try to add peer with pending approval user
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer)
require.Error(t, err)
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
}
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create regular user (not pending approval)
regularUser := types.NewRegularUser("regular-user")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
// Try to add peer with regular user
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer)
require.NoError(t, err, "Regular user should be able to add peers")
}
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Create a peer using AddPeer method for the pending user (simulate existing peer)
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
// Set the user to not be pending initially so peer can be added
pendingUser.Blocked = false
pendingUser.PendingApproval = false
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Add peer using regular flow
newPeer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer)
require.NoError(t, err)
// Now set the user back to pending approval after peer was created
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Try to login with pending approval user
login := types.PeerLogin{
WireGuardPubKey: existingPeer.Key,
UserID: pendingUser.Id,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.LoginPeer(context.Background(), login)
require.Error(t, err)
e, ok := status.FromError(err)
require.True(t, ok, "error is not a gRPC status error")
assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code")
}
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create regular user (not pending approval)
regularUser := types.NewRegularUser("regular-user")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
// Add peer using regular flow for the regular user
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
newPeer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer)
require.NoError(t, err)
// Try to login with regular user
login := types.PeerLogin{
WireGuardPubKey: existingPeer.Key,
UserID: regularUser.Id,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.LoginPeer(context.Background(), login)
require.NoError(t, err, "Regular user should be able to login peers")
}

Some files were not shown because too many files have changed in this diff Show More