diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 0013833c4..ba36c013b 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -217,7 +217,7 @@ jobs:
- arch: "386"
raceFlag: ""
- arch: "amd64"
- raceFlag: ""
+ raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -382,6 +382,32 @@ jobs:
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
+ - name: Create Docker network
+ run: docker network create promnet
+
+ - name: Start Prometheus Pushgateway
+ run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
+
+ - name: Start Prometheus (for Pushgateway forwarding)
+ run: |
+ echo '
+ global:
+ scrape_interval: 15s
+ scrape_configs:
+ - job_name: "pushgateway"
+ static_configs:
+ - targets: ["pushgateway:9091"]
+ remote_write:
+ - url: ${{ secrets.GRAFANA_URL }}
+ basic_auth:
+ username: ${{ secrets.GRAFANA_USER }}
+ password: ${{ secrets.GRAFANA_API_KEY }}
+ ' > prometheus.yml
+
+ docker run -d --name prometheus --network promnet \
+ -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
+ -p 9090:9090 \
+ prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -428,9 +454,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
+ GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \
- -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
- -timeout 20m ./management/... ./shared/management/...
+ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
+ -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
api_benchmark:
name: "Management / Benchmark (API)"
@@ -521,7 +548,7 @@ jobs:
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
- -timeout 20m ./management/... ./shared/management/...
+ -timeout 20m ./management/server/http/...
api_integration_test:
name: "Management / Integration"
@@ -571,4 +598,4 @@ jobs:
CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
- -timeout 20m ./management/... ./shared/management/...
\ No newline at end of file
+ -timeout 20m ./management/server/http/...
diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml
index d9ff0a84b..2083c0721 100644
--- a/.github/workflows/golang-test-windows.yml
+++ b/.github/workflows/golang-test-windows.yml
@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
+ - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 7be52259b..e9741f541 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -9,7 +9,7 @@ on:
pull_request:
env:
- SIGN_PIPE_VER: "v0.0.22"
+ SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
diff --git a/README.md b/README.md
index ea7655869..2c5ee2ab6 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,4 @@
+
@@ -52,7 +53,7 @@
### Open Source Network Security in a Single Platform
-

+https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
diff --git a/client/Dockerfile b/client/Dockerfile
index e19a09909..b2f627409 100644
--- a/client/Dockerfile
+++ b/client/Dockerfile
@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
- NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
+ NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
diff --git a/client/android/client.go b/client/android/client.go
index c05246569..d2d0c37f6 100644
--- a/client/android/client.go
+++ b/client/android/client.go
@@ -4,6 +4,7 @@ package android
import (
"context"
+ "os"
"slices"
"sync"
@@ -18,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
- "github.com/netbirdio/netbird/util/net"
+ "github.com/netbirdio/netbird/client/net"
)
// ConnectionListener export internal Listener for mobile
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
}
// Run start the internal client. It is a blocker function
-func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
+func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
+ exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
-func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
+func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
+ exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
+
+func exportEnvList(list *EnvList) {
+ if list == nil {
+ return
+ }
+ for k, v := range list.AllItems() {
+ if err := os.Setenv(k, v); err != nil {
+ log.Errorf("could not set env variable %s: %v", k, err)
+ }
+ }
+}
diff --git a/client/android/env_list.go b/client/android/env_list.go
new file mode 100644
index 000000000..04122300a
--- /dev/null
+++ b/client/android/env_list.go
@@ -0,0 +1,32 @@
+package android
+
+import "github.com/netbirdio/netbird/client/internal/peer"
+
+var (
+ // EnvKeyNBForceRelay Exported for Android java client
+ EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
+)
+
+// EnvList wraps a Go map for export to Java
+type EnvList struct {
+ data map[string]string
+}
+
+// NewEnvList creates a new EnvList
+func NewEnvList() *EnvList {
+ return &EnvList{data: make(map[string]string)}
+}
+
+// Put adds a key-value pair
+func (el *EnvList) Put(key, value string) {
+ el.data[key] = value
+}
+
+// Get retrieves a value by key
+func (el *EnvList) Get(key string) string {
+ return el.data[key]
+}
+
+func (el *EnvList) AllItems() map[string]string {
+ return el.data
+}
diff --git a/client/android/login.go b/client/android/login.go
index d8ac645e2..0df78dbc3 100644
--- a/client/android/login.go
+++ b/client/android/login.go
@@ -33,6 +33,7 @@ type ErrListener interface {
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
+ OnLoginSuccess()
}
// Auth can register or login new client
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
+
+ if err == nil {
+ go urlOpener.OnLoginSuccess()
+ }
+
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
diff --git a/client/cmd/down.go b/client/cmd/down.go
index 3ce51c678..17c152d22 100644
--- a/client/cmd/down.go
+++ b/client/cmd/down.go
@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err
}
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
diff --git a/client/cmd/login.go b/client/cmd/login.go
index 92de6abdb..3ac211805 100644
--- a/client/cmd/login.go
+++ b/client/cmd/login.go
@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
}
// update host's static platform and system information
- system.UpdateStaticInfo()
+ system.UpdateStaticInfoAsync()
configFilePath, err := activeProf.FilePath()
if err != nil {
diff --git a/client/cmd/root.go b/client/cmd/root.go
index 290cae258..9f2eb109c 100644
--- a/client/cmd/root.go
+++ b/client/cmd/root.go
@@ -228,7 +228,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
- ctx, cancel := context.WithTimeout(ctx, time.Second*3)
+ ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
return grpc.DialContext(
diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go
index 50fb35d5e..0545ce6b7 100644
--- a/client/cmd/service_controller.go
+++ b/client/cmd/service_controller.go
@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint
// Collect static system and platform information
- system.UpdateStaticInfo()
+ system.UpdateStaticInfoAsync()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer()
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index e45443751..99ccb1539 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -9,29 +9,26 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
+ "google.golang.org/grpc"
+ "github.com/netbirdio/management-integrations/integrations"
+ clientProto "github.com/netbirdio/netbird/client/proto"
+ client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
+ mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
+ "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
-
- "github.com/netbirdio/netbird/util"
-
- "google.golang.org/grpc"
-
- "github.com/netbirdio/management-integrations/integrations"
-
- clientProto "github.com/netbirdio/netbird/client/proto"
- client "github.com/netbirdio/netbird/client/server"
- mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server"
+ "github.com/netbirdio/netbird/util"
)
func startTestingServices(t *testing.T) string {
@@ -90,15 +87,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
return nil, nil
}
- iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
- metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
- require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
- settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
+ peersmanager := peers.NewManager(store, permissionsManagerMock)
+ settingsManagerMock := settings.NewMockManager(ctrl)
+
+ iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ require.NoError(t, err)
+
+ settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT().
diff --git a/client/cmd/up.go b/client/cmd/up.go
index 61b442cea..1a553711d 100644
--- a/client/cmd/up.go
+++ b/client/cmd/up.go
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn)
- status, err := client.Status(ctx, &proto.StatusRequest{})
+ status, err := client.Status(ctx, &proto.StatusRequest{
+ WaitForReady: func() *bool { b := true; return &b }(),
+ })
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
diff --git a/client/embed/embed.go b/client/embed/embed.go
index de83f9d96..0bfc7a37c 100644
--- a/client/embed/embed.go
+++ b/client/embed/embed.go
@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
- run := make(chan struct{}, 1)
+ run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go
index 7b90000a8..ed8a7403b 100644
--- a/client/firewall/iptables/acl_linux.go
+++ b/client/firewall/iptables/acl_linux.go
@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
const (
diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go
index d8e8857d4..343f5e05e 100644
--- a/client/firewall/iptables/router_linux.go
+++ b/client/firewall/iptables/router_linux.go
@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
// constants needed to manage and create iptable rules
diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go
index e9eeff863..3490c5dad 100644
--- a/client/firewall/iptables/router_linux_test.go
+++ b/client/firewall/iptables/router_linux_test.go
@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
func isIptablesSupported() bool {
diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go
index 52979d257..9ff5b8c92 100644
--- a/client/firewall/nftables/acl_linux.go
+++ b/client/firewall/nftables/acl_linux.go
@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
const (
diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go
index aa4098821..0c091da96 100644
--- a/client/firewall/nftables/router_linux.go
+++ b/client/firewall/nftables/router_linux.go
@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
const (
diff --git a/util/grpc/dialer.go b/client/grpc/dialer.go
similarity index 91%
rename from util/grpc/dialer.go
rename to client/grpc/dialer.go
index f6d6d2f04..69e3f088c 100644
--- a/util/grpc/dialer.go
+++ b/client/grpc/dialer.go
@@ -20,8 +20,9 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
+ nbnet "github.com/netbirdio/netbird/client/net"
+
"github.com/netbirdio/netbird/util/embeddedroots"
- nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
@@ -57,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
-func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
+func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
certPool, err := x509.SystemCertPool()
@@ -71,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
}))
}
- connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go
index 89bddf12c..32b07c330 100644
--- a/client/iface/bind/control.go
+++ b/client/iface/bind/control.go
@@ -3,7 +3,7 @@ package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go
index f23be406e..577c7c0c4 100644
--- a/client/iface/bind/ice_bind.go
+++ b/client/iface/bind/ice_bind.go
@@ -8,15 +8,16 @@ import (
"runtime"
"sync"
- "github.com/pion/stun/v2"
+ "github.com/pion/stun/v3"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type RecvMessage struct {
@@ -44,7 +45,7 @@ type ICEBind struct {
RecvChan chan RecvMessage
transportNet transport.Net
- filterFn FilterFn
+ filterFn udpmux.FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
@@ -54,13 +55,13 @@ type ICEBind struct {
closed bool
muUDPMux sync.Mutex
- udpMux *UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder
}
-func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
+func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
-func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
+func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
@@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
- s.udpMux = NewUniversalUDPMuxDefault(
- UniversalUDPMuxParams{
+ s.udpMux = udpmux.NewUniversalUDPMuxDefault(
+ udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go
deleted file mode 100644
index db0249d11..000000000
--- a/client/iface/bind/udp_mux_ios.go
+++ /dev/null
@@ -1,7 +0,0 @@
-//go:build ios
-
-package bind
-
-func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
- // iOS doesn't support nbnet hooks, so this is a no-op
-}
diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go
index 171458e38..f744e0127 100644
--- a/client/iface/configurer/usp.go
+++ b/client/iface/configurer/usp.go
@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
+ nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime"
- nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
+
+ // If sec is 0 (Unix epoch), return zero time instead
+ // This indicates no handshake has occurred
+ if sec == 0 {
+ return time.Time{}, nil
+ }
+
return time.Unix(sec, 0), nil
}
@@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) {
}
func getFwmark() int {
- if nbnet.AdvancedRouting() {
+ if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
return nbnet.ControlPlaneMark
}
return 0
diff --git a/client/iface/device.go b/client/iface/device.go
index ca6dda2c2..921f0ea98 100644
--- a/client/iface/device.go
+++ b/client/iface/device.go
@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
- "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
- Up() (*bind.UniversalUDPMuxDefault, error)
+ Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16
diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go
index fe3b9f82e..a731684cc 100644
--- a/client/iface/device/device_android.go
+++ b/client/iface/device/device_android.go
@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +30,7 @@ type WGTunDevice struct {
name string
device *device.Device
filteredDevice *FilteredDevice
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
}
return t.configurer, nil
}
-func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go
index cce9d42df..390efe088 100644
--- a/client/iface/device/device_darwin.go
+++ b/client/iface/device/device_darwin.go
@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -26,7 +27,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
-func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go
index 168985b5e..96e4c8bcf 100644
--- a/client/iface/device/device_ios.go
+++ b/client/iface/device/device_ios.go
@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -28,7 +29,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
-func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go
index 00a72bcc6..cdac43a53 100644
--- a/client/iface/device/device_kernel_unix.go
+++ b/client/iface/device/device_kernel_unix.go
@@ -12,11 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
- "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
+ nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock"
- nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
- filterFn bind.FilterFn
+ filterFn udpmux.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil
}
-func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -101,19 +101,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
- var udpConn net.PacketConn = rawSock
- if !nbnet.AdvancedRouting() {
- udpConn = nbnet.WrapPacketConn(rawSock)
- }
-
- bindParams := bind.UniversalUDPMuxParams{
- UDPConn: udpConn,
+ bindParams := udpmux.UniversalUDPMuxParams{
+ UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}
- mux := bind.NewUniversalUDPMuxDefault(bindParams)
+ mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux
diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go
index f41331ff7..a6ef47027 100644
--- a/client/iface/device/device_netstack.go
+++ b/client/iface/device/device_netstack.go
@@ -10,8 +10,9 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type TunNetstackDevice struct {
@@ -26,7 +27,7 @@ type TunNetstackDevice struct {
device *device.Device
filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
@@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil
}
-func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go
index 8d30112ae..4cdd70a32 100644
--- a/client/iface/device/device_usp_unix.go
+++ b/client/iface/device/device_usp_unix.go
@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -25,7 +26,7 @@ type USPDevice struct {
device *device.Device
filteredDevice *FilteredDevice
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
-func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go
index de258868f..f1023bc0a 100644
--- a/client/iface/device/device_windows.go
+++ b/client/iface/device/device_windows.go
@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +30,7 @@ type TunDevice struct {
device *device.Device
nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
-func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
diff --git a/client/iface/device_android.go b/client/iface/device_android.go
index 39b5c28ae..4649b8b97 100644
--- a/client/iface/device_android.go
+++ b/client/iface/device_android.go
@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
- "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
- Up() (*bind.UniversalUDPMuxDefault, error)
+ Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16
diff --git a/client/iface/iface.go b/client/iface/iface.go
index 9a42223a1..609572561 100644
--- a/client/iface/iface.go
+++ b/client/iface/iface.go
@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors"
- "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
- FilterFn bind.FilterFn
+ FilterFn udpmux.FilterFn
DisableDNS bool
}
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
-func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
w.mu.Lock()
defer w.mu.Unlock()
diff --git a/client/iface/bind/udp_muxed_conn.go b/client/iface/udpmux/conn.go
similarity index 95%
rename from client/iface/bind/udp_muxed_conn.go
rename to client/iface/udpmux/conn.go
index 7cacf1c31..3aa40caeb 100644
--- a/client/iface/bind/udp_muxed_conn.go
+++ b/client/iface/udpmux/conn.go
@@ -1,4 +1,4 @@
-package bind
+package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,11 +16,12 @@ import (
)
type udpMuxedConnParams struct {
- Mux *UDPMuxDefault
- AddrPool *sync.Pool
- Key string
- LocalAddr net.Addr
- Logger logging.LeveledLogger
+ Mux *SingleSocketUDPMux
+ AddrPool *sync.Pool
+ Key string
+ LocalAddr net.Addr
+ Logger logging.LeveledLogger
+ CandidateID string
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
return err
}
+func (c *udpMuxedConn) GetCandidateID() string {
+ return c.params.CandidateID
+}
+
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
diff --git a/client/iface/udpmux/doc.go b/client/iface/udpmux/doc.go
new file mode 100644
index 000000000..27e5e43bc
--- /dev/null
+++ b/client/iface/udpmux/doc.go
@@ -0,0 +1,64 @@
+// Package udpmux provides a custom implementation of a UDP multiplexer
+// that allows multiple logical ICE connections to share a single underlying
+// UDP socket. This is based on Pion's ICE library, with modifications for
+// NetBird's requirements.
+//
+// # Background
+//
+// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
+// Establishment) is responsible for discovering candidate network paths
+// and maintaining connectivity between peers. Each ICE connection
+// normally requires a dedicated UDP socket. However, using one socket
+// per candidate can be inefficient and difficult to manage.
+//
+// This package introduces SingleSocketUDPMux, which allows multiple ICE
+// candidate connections (muxed connections) to share a single UDP socket.
+// It handles demultiplexing of packets based on ICE ufrag values, STUN
+// attributes, and candidate IDs.
+//
+// # Usage
+//
+// The typical flow is:
+//
+// 1. Create a UDP socket (net.PacketConn).
+// 2. Construct Params with the socket and optional logger/net stack.
+// 3. Call NewSingleSocketUDPMux(params).
+// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
+// to obtain a logical PacketConn.
+// 5. Use the returned PacketConn just like a normal UDP connection.
+//
+// # STUN Message Routing Logic
+//
+// When a STUN packet arrives, the mux decides which connection should
+// receive it using this routing logic:
+//
+// Primary Routing: Candidate Pair ID
+// - Extract the candidate pair ID from the STUN message using
+// ice.CandidatePairIDFromSTUN(msg)
+// - The target candidate is the locally generated candidate that
+// corresponds to the connection that should handle this STUN message
+// - If found, use the target candidate ID to lookup the specific
+// connection in candidateConnMap
+// - Route the message directly to that connection
+//
+// Fallback Routing: Broadcasting
+// When candidate pair ID is not available or lookup fails:
+// - Collect connections from addressMap based on source address
+// - Find connection using username attribute (ufrag) from STUN message
+// - Remove duplicate connections from the list
+// - Send the STUN message to all collected connections
+//
+// # Peer Reflexive Candidate Discovery
+//
+// When a remote peer sends a STUN message from an unknown source address
+// (from a candidate that has not been exchanged via signal), the ICE
+// library will:
+// - Generate a new peer reflexive candidate for this source address
+// - Extract or assign a candidate ID based on the STUN message attributes
+// - Create a mapping between the new peer reflexive candidate ID and
+// the appropriate local connection
+//
+// This discovery mechanism ensures that STUN messages from newly discovered
+// peer reflexive candidates can be properly routed to the correct local
+// connection without requiring fallback broadcasting.
+package udpmux
diff --git a/client/iface/bind/udp_mux.go b/client/iface/udpmux/mux.go
similarity index 65%
rename from client/iface/bind/udp_mux.go
rename to client/iface/udpmux/mux.go
index 29e5d7937..319724926 100644
--- a/client/iface/bind/udp_mux.go
+++ b/client/iface/udpmux/mux.go
@@ -1,4 +1,4 @@
-package bind
+package udpmux
import (
"fmt"
@@ -8,9 +8,9 @@ import (
"strings"
"sync"
- "github.com/pion/ice/v3"
+ "github.com/pion/ice/v4"
"github.com/pion/logging"
- "github.com/pion/stun/v2"
+ "github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192
-// UDPMuxDefault is an implementation of the interface
-type UDPMuxDefault struct {
- params UDPMuxParams
+// SingleSocketUDPMux is an implementation of the interface
+type SingleSocketUDPMux struct {
+ params Params
closedChan chan struct{}
closeOnce sync.Once
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
+ // candidateConnMap maps local candidate IDs to their corresponding connection.
+ candidateConnMap map[string]*udpMuxedConn
+
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
const maxAddrSize = 512
-// UDPMuxParams are parameters for UDPMux.
-type UDPMuxParams struct {
+// Params are parameters for UDPMux.
+type Params struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
@@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool {
return true
}
-// NewUDPMuxDefault creates an implementation of UDPMux
-func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
+// NewSingleSocketUDPMux creates an implementation of UDPMux
+func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
if params.Logger == nil {
params.Logger = getLogger()
}
- mux := &UDPMuxDefault{
- addressMap: map[string][]*udpMuxedConn{},
- params: params,
- connsIPv4: make(map[string]*udpMuxedConn),
- connsIPv6: make(map[string]*udpMuxedConn),
- closedChan: make(chan struct{}, 1),
+ mux := &SingleSocketUDPMux{
+ addressMap: map[string][]*udpMuxedConn{},
+ params: params,
+ connsIPv4: make(map[string]*udpMuxedConn),
+ connsIPv6: make(map[string]*udpMuxedConn),
+ candidateConnMap: make(map[string]*udpMuxedConn),
+ closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return mux
}
-func (m *UDPMuxDefault) updateLocalAddresses() {
+func (m *SingleSocketUDPMux) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
- // with UDPMuxDefault, so print a warn log and create a local address list for mux.
- m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
+ // with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
+ m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
m.mu.Unlock()
}
-// LocalAddr returns the listening address of this UDPMuxDefault
-func (m *UDPMuxDefault) LocalAddr() net.Addr {
+// LocalAddr returns the listening address of this SingleSocketUDPMux
+func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
-func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
+func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
-func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
+func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return conn, nil
}
- c := m.createMuxedConn(ufrag)
+ c := m.createMuxedConn(ufrag, candidateID)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
+ m.candidateConnMap[candidateID] = c
+
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
}
// RemoveConnByUfrag stops and removes the muxed packet connection
-func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
+func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
+ delete(m.candidateConnMap, c.GetCandidateID())
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
+ delete(m.candidateConnMap, c.GetCandidateID())
}
m.mu.Unlock()
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
// IsClosed returns true if the mux had been closed
-func (m *UDPMuxDefault) IsClosed() bool {
+func (m *SingleSocketUDPMux) IsClosed() bool {
select {
case <-m.closedChan:
return true
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
}
// Close the mux, no further connections could be created
-func (m *UDPMuxDefault) Close() error {
+func (m *SingleSocketUDPMux) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
return err
}
-func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
+func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
-func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
+func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
-func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
+func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
- Mux: m,
- Key: key,
- AddrPool: m.pool,
- LocalAddr: m.LocalAddr(),
- Logger: m.params.Logger,
+ Mux: m,
+ Key: key,
+ AddrPool: m.pool,
+ LocalAddr: m.LocalAddr(),
+ Logger: m.params.Logger,
+ CandidateID: candidateID,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
-func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
-
+func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
- // If we have already seen this address dispatch to the appropriate destination
- // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
- // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
- // We will then forward STUN packets to each of these connections.
- m.addressMapMu.RLock()
+ // Try to route to specific candidate connection first
+ if conn := m.findCandidateConnection(msg); conn != nil {
+ return conn.writePacket(msg.Raw, remoteAddr)
+ }
+
+ // Fallback: route to all possible connections
+ return m.forwardToAllConnections(msg, addr, remoteAddr)
+}
+
+// findCandidateConnection attempts to find the specific connection for a STUN message
+func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
+ candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
+ if err != nil {
+ return nil
+ } else if !ok {
+ return nil
+ }
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
+ if !exists {
+ return nil
+ }
+ return conn
+}
+
+// forwardToAllConnections forwards STUN message to all relevant connections
+func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn
+
+ // Add connections from address map
+ m.addressMapMu.RLock()
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.RUnlock()
- var isIPv6 bool
- if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
- isIPv6 = true
+ if conn, ok := m.findConnectionByUsername(msg, addr); ok {
+ // If we have already seen this address dispatch to the appropriate destination
+ // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
+ // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
+ // We will then forward STUN packets to each of these connections.
+ if !m.connectionExists(conn, destinationConnList) {
+ destinationConnList = append(destinationConnList, conn)
+ }
}
- // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
- // However, we can take a username attribute from the STUN message which contains ufrag.
- // We can use ufrag to identify the destination conn to route packet to.
- attr, stunAttrErr := msg.Get(stun.AttrUsername)
- if stunAttrErr == nil {
- ufrag := strings.Split(string(attr), ":")[0]
-
- m.mu.Lock()
- destinationConn := m.connsIPv4[ufrag]
- if isIPv6 {
- destinationConn = m.connsIPv6[ufrag]
- }
-
- if destinationConn != nil {
- exists := false
- for _, conn := range destinationConnList {
- if conn.params.Key == destinationConn.params.Key {
- exists = true
- break
- }
- }
- if !exists {
- destinationConnList = append(destinationConnList, destinationConn)
- }
- }
- m.mu.Unlock()
- }
-
- // Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
- // It will be discarded by the further ICE candidate logic if so.
+ // Forward to all found connections
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
-
return nil
}
-func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
+// findConnectionByUsername finds connection using username attribute from STUN message
+func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
+ attr, err := msg.Get(stun.AttrUsername)
+ if err != nil {
+ return nil, false
+ }
+
+ ufrag := strings.Split(string(attr), ":")[0]
+ isIPv6 := isIPv6Address(addr)
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ return m.getConn(ufrag, isIPv6)
+}
+
+// connectionExists checks if a connection already exists in the list
+func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
+ for _, conn := range conns {
+ if conn.params.Key == target.params.Key {
+ return true
+ }
+ }
+ return false
+}
+
+func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
return
}
+func isIPv6Address(addr net.Addr) bool {
+ if udpAddr, ok := addr.(*net.UDPAddr); ok {
+ return udpAddr.IP.To4() == nil
+ }
+ return false
+}
+
type bufferHolder struct {
buf []byte
}
diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/udpmux/mux_generic.go
similarity index 76%
rename from client/iface/bind/udp_mux_generic.go
rename to client/iface/udpmux/mux_generic.go
index 63f786d2b..29fc2d834 100644
--- a/client/iface/bind/udp_mux_generic.go
+++ b/client/iface/udpmux/mux_generic.go
@@ -1,12 +1,12 @@
//go:build !ios
-package bind
+package udpmux
import (
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
-func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
+func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)
diff --git a/client/iface/udpmux/mux_ios.go b/client/iface/udpmux/mux_ios.go
new file mode 100644
index 000000000..4cf211d8f
--- /dev/null
+++ b/client/iface/udpmux/mux_ios.go
@@ -0,0 +1,7 @@
+//go:build ios
+
+package udpmux
+
+func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
+ // iOS doesn't support nbnet hooks, so this is a no-op
+}
diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/udpmux/universal.go
similarity index 96%
rename from client/iface/bind/udp_mux_universal.go
rename to client/iface/udpmux/universal.go
index b06da6712..43bfedaaa 100644
--- a/client/iface/bind/udp_mux_universal.go
+++ b/client/iface/udpmux/universal.go
@@ -1,4 +1,4 @@
-package bind
+package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -15,7 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
- "github.com/pion/stun/v2"
+ "github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize"
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
- *UDPMuxDefault
+ *SingleSocketUDPMux
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
- udpMuxParams := UDPMuxParams{
+ udpMuxParams := Params{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
- m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
+ m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
return m
}
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
-func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
- return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
+func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
+ return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
}
return nil
}
- return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
+ return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go
index fcdc0189d..b899f1694 100644
--- a/client/iface/wgproxy/ebpf/proxy.go
+++ b/client/iface/wgproxy/ebpf/proxy.go
@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
const (
diff --git a/client/internal/connect.go b/client/internal/connect.go
index a3cc7be1d..2bfa263fc 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -34,7 +34,7 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -275,7 +275,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock()
- if err := c.engine.Start(); err != nil {
+ if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -284,10 +284,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected)
if runningChan != nil {
- select {
- case runningChan <- struct{}{}:
- default:
- }
+ close(runningChan)
+ runningChan = nil
}
<-engineCtx.Done()
diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go
new file mode 100644
index 000000000..cb651f1e5
--- /dev/null
+++ b/client/internal/dns/config/domains.go
@@ -0,0 +1,201 @@
+package config
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "net/url"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/shared/management/domain"
+ mgmProto "github.com/netbirdio/netbird/shared/management/proto"
+)
+
+var (
+ ErrEmptyURL = errors.New("empty URL")
+ ErrEmptyHost = errors.New("empty host")
+ ErrIPNotAllowed = errors.New("IP address not allowed")
+)
+
+// ServerDomains represents the management server domains extracted from NetBird configuration
+type ServerDomains struct {
+ Signal domain.Domain
+ Relay []domain.Domain
+ Flow domain.Domain
+ Stuns []domain.Domain
+ Turns []domain.Domain
+}
+
+// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
+func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
+ if config == nil {
+ return ServerDomains{}
+ }
+
+ domains := ServerDomains{}
+
+ domains.Signal = extractSignalDomain(config)
+ domains.Relay = extractRelayDomains(config)
+ domains.Flow = extractFlowDomain(config)
+ domains.Stuns = extractStunDomains(config)
+ domains.Turns = extractTurnDomains(config)
+
+ return domains
+}
+
+// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses
+func ExtractValidDomain(rawURL string) (domain.Domain, error) {
+ if rawURL == "" {
+ return "", ErrEmptyURL
+ }
+
+ parsedURL, err := url.Parse(rawURL)
+ if err == nil {
+ if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" {
+ return domain, err
+ }
+ }
+
+ return extractFromRawString(rawURL)
+}
+
+// extractFromParsedURL handles domain extraction from successfully parsed URLs
+func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) {
+ if parsedURL.Hostname() != "" {
+ return extractDomainFromHost(parsedURL.Hostname())
+ }
+
+ if parsedURL.Opaque == "" || parsedURL.Scheme == "" {
+ return "", nil
+ }
+
+ // Handle URLs with opaque content (e.g., stun:host:port)
+ if strings.Contains(parsedURL.Scheme, ".") {
+ // This is likely "domain.com:port" being parsed as scheme:opaque
+ reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque
+ if host, _, err := net.SplitHostPort(reconstructed); err == nil {
+ return extractDomainFromHost(host)
+ }
+ return extractDomainFromHost(parsedURL.Scheme)
+ }
+
+ // Valid scheme with opaque content (e.g., stun:host:port)
+ host := parsedURL.Opaque
+ if queryIndex := strings.Index(host, "?"); queryIndex > 0 {
+ host = host[:queryIndex]
+ }
+
+ if hostOnly, _, err := net.SplitHostPort(host); err == nil {
+ return extractDomainFromHost(hostOnly)
+ }
+
+ return extractDomainFromHost(host)
+}
+
+// extractFromRawString handles domain extraction when URL parsing fails or returns no results
+func extractFromRawString(rawURL string) (domain.Domain, error) {
+ if host, _, err := net.SplitHostPort(rawURL); err == nil {
+ return extractDomainFromHost(host)
+ }
+
+ return extractDomainFromHost(rawURL)
+}
+
+// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
+func extractDomainFromHost(host string) (domain.Domain, error) {
+ if host == "" {
+ return "", ErrEmptyHost
+ }
+
+ if _, err := netip.ParseAddr(host); err == nil {
+ return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host)
+ }
+
+ d, err := domain.FromString(host)
+ if err != nil {
+ return "", fmt.Errorf("invalid domain: %v", err)
+ }
+
+ return d, nil
+}
+
+// extractSingleDomain extracts a single domain from a URL with error logging
+func extractSingleDomain(url, serviceType string) domain.Domain {
+ if url == "" {
+ return ""
+ }
+
+ d, err := ExtractValidDomain(url)
+ if err != nil {
+ log.Debugf("Skipping %s: %v", serviceType, err)
+ return ""
+ }
+
+ return d
+}
+
+// extractMultipleDomains extracts multiple domains from URLs with error logging
+func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
+ var domains []domain.Domain
+ for _, url := range urls {
+ if url == "" {
+ continue
+ }
+ d, err := ExtractValidDomain(url)
+ if err != nil {
+ log.Debugf("Skipping %s: %v", serviceType, err)
+ continue
+ }
+ domains = append(domains, d)
+ }
+ return domains
+}
+
+// extractSignalDomain extracts the signal domain from NetBird configuration.
+func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
+ if config.Signal != nil {
+ return extractSingleDomain(config.Signal.Uri, "signal")
+ }
+ return ""
+}
+
+// extractRelayDomains extracts relay server domains from NetBird configuration.
+func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
+ if config.Relay != nil {
+ return extractMultipleDomains(config.Relay.Urls, "relay")
+ }
+ return nil
+}
+
+// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
+func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
+ if config.Flow != nil {
+ return extractSingleDomain(config.Flow.Url, "flow")
+ }
+ return ""
+}
+
+// extractStunDomains extracts STUN server domains from NetBird configuration.
+func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
+ var urls []string
+ for _, stun := range config.Stuns {
+ if stun != nil && stun.Uri != "" {
+ urls = append(urls, stun.Uri)
+ }
+ }
+ return extractMultipleDomains(urls, "STUN")
+}
+
+// extractTurnDomains extracts TURN server domains from NetBird configuration.
+func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
+ var urls []string
+ for _, turn := range config.Turns {
+ if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
+ urls = append(urls, turn.HostConfig.Uri)
+ }
+ }
+ return extractMultipleDomains(urls, "TURN")
+}
diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go
new file mode 100644
index 000000000..5eae3a541
--- /dev/null
+++ b/client/internal/dns/config/domains_test.go
@@ -0,0 +1,213 @@
+package config
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestExtractValidDomain(t *testing.T) {
+ tests := []struct {
+ name string
+ url string
+ expected string
+ expectError bool
+ }{
+ {
+ name: "HTTPS URL with port",
+ url: "https://api.netbird.io:443",
+ expected: "api.netbird.io",
+ },
+ {
+ name: "HTTP URL without port",
+ url: "http://signal.example.com",
+ expected: "signal.example.com",
+ },
+ {
+ name: "Host with port (no scheme)",
+ url: "signal.netbird.io:443",
+ expected: "signal.netbird.io",
+ },
+ {
+ name: "STUN URL",
+ url: "stun:stun.netbird.io:443",
+ expected: "stun.netbird.io",
+ },
+ {
+ name: "STUN URL with different port",
+ url: "stun:stun.netbird.io:5555",
+ expected: "stun.netbird.io",
+ },
+ {
+ name: "TURNS URL with query params",
+ url: "turns:turn.netbird.io:443?transport=tcp",
+ expected: "turn.netbird.io",
+ },
+ {
+ name: "TURN URL",
+ url: "turn:turn.example.com:3478",
+ expected: "turn.example.com",
+ },
+ {
+ name: "REL URL",
+ url: "rel://relay.example.com:443",
+ expected: "relay.example.com",
+ },
+ {
+ name: "RELS URL",
+ url: "rels://relay.netbird.io:443",
+ expected: "relay.netbird.io",
+ },
+ {
+ name: "Raw hostname",
+ url: "example.org",
+ expected: "example.org",
+ },
+ {
+ name: "IP address should be rejected",
+ url: "192.168.1.1",
+ expectError: true,
+ },
+ {
+ name: "IP address with port should be rejected",
+ url: "192.168.1.1:443",
+ expectError: true,
+ },
+ {
+ name: "IPv6 address should be rejected",
+ url: "2001:db8::1",
+ expectError: true,
+ },
+ {
+ name: "HTTP URL with IPv4 should be rejected",
+ url: "http://192.168.1.1:8080",
+ expectError: true,
+ },
+ {
+ name: "HTTPS URL with IPv4 should be rejected",
+ url: "https://10.0.0.1:443",
+ expectError: true,
+ },
+ {
+ name: "STUN URL with IPv4 should be rejected",
+ url: "stun:192.168.1.1:3478",
+ expectError: true,
+ },
+ {
+ name: "TURN URL with IPv4 should be rejected",
+ url: "turn:10.0.0.1:3478",
+ expectError: true,
+ },
+ {
+ name: "TURNS URL with IPv4 should be rejected",
+ url: "turns:172.16.0.1:5349",
+ expectError: true,
+ },
+ {
+ name: "HTTP URL with IPv6 should be rejected",
+ url: "http://[2001:db8::1]:8080",
+ expectError: true,
+ },
+ {
+ name: "HTTPS URL with IPv6 should be rejected",
+ url: "https://[::1]:443",
+ expectError: true,
+ },
+ {
+ name: "STUN URL with IPv6 should be rejected",
+ url: "stun:[2001:db8::1]:3478",
+ expectError: true,
+ },
+ {
+ name: "IPv6 with port should be rejected",
+ url: "[2001:db8::1]:443",
+ expectError: true,
+ },
+ {
+ name: "Localhost IPv4 should be rejected",
+ url: "127.0.0.1:8080",
+ expectError: true,
+ },
+ {
+ name: "Localhost IPv6 should be rejected",
+ url: "[::1]:443",
+ expectError: true,
+ },
+ {
+ name: "REL URL with IPv4 should be rejected",
+ url: "rel://192.168.1.1:443",
+ expectError: true,
+ },
+ {
+ name: "RELS URL with IPv4 should be rejected",
+ url: "rels://10.0.0.1:443",
+ expectError: true,
+ },
+ {
+ name: "Empty URL",
+ url: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := ExtractValidDomain(tt.url)
+
+ if tt.expectError {
+ assert.Error(t, err, "Expected error for URL: %s", tt.url)
+ } else {
+ assert.NoError(t, err, "Unexpected error for URL: %s", tt.url)
+ assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url)
+ }
+ })
+ }
+}
+
+func TestExtractDomainFromHost(t *testing.T) {
+ tests := []struct {
+ name string
+ host string
+ expected string
+ expectError bool
+ }{
+ {
+ name: "Valid domain",
+ host: "example.com",
+ expected: "example.com",
+ },
+ {
+ name: "Subdomain",
+ host: "api.example.com",
+ expected: "api.example.com",
+ },
+ {
+ name: "IPv4 address",
+ host: "192.168.1.1",
+ expectError: true,
+ },
+ {
+ name: "IPv6 address",
+ host: "2001:db8::1",
+ expectError: true,
+ },
+ {
+ name: "Empty host",
+ host: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := extractDomainFromHost(tt.host)
+
+ if tt.expectError {
+ assert.Error(t, err, "Expected error for host: %s", tt.host)
+ } else {
+ assert.NoError(t, err, "Unexpected error for host: %s", tt.host)
+ assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host)
+ }
+ })
+ }
+}
diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go
index 439bcbb3c..2e54bffd9 100644
--- a/client/internal/dns/handler_chain.go
+++ b/client/internal/dns/handler_chain.go
@@ -11,11 +11,12 @@ import (
)
const (
- PriorityLocal = 100
- PriorityDNSRoute = 75
- PriorityUpstream = 50
- PriorityDefault = 1
- PriorityFallback = -100
+ PriorityMgmtCache = 150
+ PriorityLocal = 100
+ PriorityDNSRoute = 75
+ PriorityUpstream = 50
+ PriorityDefault = 1
+ PriorityFallback = -100
)
type SubdomainMatcher interface {
@@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
- log.Tracef("handler requested continue to next handler for domain=%s", qname)
+ // Only log continue for non-management cache handlers to reduce noise
+ if entry.Priority != PriorityMgmtCache {
+ log.Tracef("handler requested continue to next handler for domain=%s", qname)
+ }
continue
}
return
diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go
index 852dfef48..b06ba73ab 100644
--- a/client/internal/dns/host_darwin.go
+++ b/client/internal/dns/host_darwin.go
@@ -166,9 +166,10 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
- err := s.recordSystemDNSSettings(true)
- log.Errorf("Unable to get system DNS configuration")
- return err
+ if err := s.recordSystemDNSSettings(true); err != nil {
+ log.Errorf("Unable to get system DNS configuration")
+ return fmt.Errorf("recordSystemDNSSettings(): %w", err)
+ }
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go
index b776fbbe3..bac7875ec 100644
--- a/client/internal/dns/local/local.go
+++ b/client/internal/dns/local/local.go
@@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool {
// String returns a string representation of the local resolver
func (d *Resolver) String() string {
- return fmt.Sprintf("local resolver [%d records]", len(d.records))
+ return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go
new file mode 100644
index 000000000..290395473
--- /dev/null
+++ b/client/internal/dns/mgmt/mgmt.go
@@ -0,0 +1,360 @@
+package mgmt
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+ log "github.com/sirupsen/logrus"
+
+ dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
+ "github.com/netbirdio/netbird/shared/management/domain"
+)
+
+const dnsTimeout = 5 * time.Second
+
+// Resolver caches critical NetBird infrastructure domains
+type Resolver struct {
+ records map[dns.Question][]dns.RR
+ mgmtDomain *domain.Domain
+ serverDomains *dnsconfig.ServerDomains
+ mutex sync.RWMutex
+}
+
+// NewResolver creates a new management domains cache resolver.
+func NewResolver() *Resolver {
+ return &Resolver{
+ records: make(map[dns.Question][]dns.RR),
+ }
+}
+
+// String returns a string representation of the resolver.
+func (m *Resolver) String() string {
+ return "MgmtCacheResolver"
+}
+
+// ServeDNS implements dns.Handler interface.
+func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+ if len(r.Question) == 0 {
+ m.continueToNext(w, r)
+ return
+ }
+
+ question := r.Question[0]
+ question.Name = strings.ToLower(dns.Fqdn(question.Name))
+
+ if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
+ m.continueToNext(w, r)
+ return
+ }
+
+ m.mutex.RLock()
+ records, found := m.records[question]
+ m.mutex.RUnlock()
+
+ if !found {
+ m.continueToNext(w, r)
+ return
+ }
+
+ resp := &dns.Msg{}
+ resp.SetReply(r)
+ resp.Authoritative = false
+ resp.RecursionAvailable = true
+
+ resp.Answer = append(resp.Answer, records...)
+
+ log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
+
+ if err := w.WriteMsg(resp); err != nil {
+ log.Errorf("failed to write response: %v", err)
+ }
+}
+
+// MatchSubdomains returns false since this resolver only handles exact domain matches
+// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
+func (m *Resolver) MatchSubdomains() bool {
+ return false
+}
+
+// continueToNext signals the handler chain to continue to the next handler.
+func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
+ resp := &dns.Msg{}
+ resp.SetRcode(r, dns.RcodeNameError)
+ resp.MsgHdr.Zero = true
+ if err := w.WriteMsg(resp); err != nil {
+ log.Errorf("failed to write continue signal: %v", err)
+ }
+}
+
+// AddDomain manually adds a domain to cache by resolving it.
+func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
+ dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
+
+ ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
+ defer cancel()
+
+ ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
+ if err != nil {
+ return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
+ }
+
+ var aRecords, aaaaRecords []dns.RR
+ for _, ip := range ips {
+ if ip.Is4() {
+ rr := &dns.A{
+ Hdr: dns.RR_Header{
+ Name: dnsName,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ Ttl: 300,
+ },
+ A: ip.AsSlice(),
+ }
+ aRecords = append(aRecords, rr)
+ } else if ip.Is6() {
+ rr := &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: dnsName,
+ Rrtype: dns.TypeAAAA,
+ Class: dns.ClassINET,
+ Ttl: 300,
+ },
+ AAAA: ip.AsSlice(),
+ }
+ aaaaRecords = append(aaaaRecords, rr)
+ }
+ }
+
+ m.mutex.Lock()
+
+ if len(aRecords) > 0 {
+ aQuestion := dns.Question{
+ Name: dnsName,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }
+ m.records[aQuestion] = aRecords
+ }
+
+ if len(aaaaRecords) > 0 {
+ aaaaQuestion := dns.Question{
+ Name: dnsName,
+ Qtype: dns.TypeAAAA,
+ Qclass: dns.ClassINET,
+ }
+ m.records[aaaaQuestion] = aaaaRecords
+ }
+
+ m.mutex.Unlock()
+
+ log.Debugf("added domain=%s with %d A records and %d AAAA records",
+ d.SafeString(), len(aRecords), len(aaaaRecords))
+
+ return nil
+}
+
+// PopulateFromConfig extracts and caches domains from the client configuration.
+func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
+ if mgmtURL == nil {
+ return nil
+ }
+
+ d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
+ if err != nil {
+ return fmt.Errorf("extract domain from URL: %w", err)
+ }
+
+ m.mutex.Lock()
+ m.mgmtDomain = &d
+ m.mutex.Unlock()
+
+ if err := m.AddDomain(ctx, d); err != nil {
+ return fmt.Errorf("add domain: %w", err)
+ }
+
+ return nil
+}
+
+// RemoveDomain removes a domain from the cache.
+func (m *Resolver) RemoveDomain(d domain.Domain) error {
+ dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
+
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ aQuestion := dns.Question{
+ Name: dnsName,
+ Qtype: dns.TypeA,
+ Qclass: dns.ClassINET,
+ }
+ delete(m.records, aQuestion)
+
+ aaaaQuestion := dns.Question{
+ Name: dnsName,
+ Qtype: dns.TypeAAAA,
+ Qclass: dns.ClassINET,
+ }
+ delete(m.records, aaaaQuestion)
+
+ log.Debugf("removed domain=%s from cache", d.SafeString())
+ return nil
+}
+
+// GetCachedDomains returns a list of all cached domains.
+func (m *Resolver) GetCachedDomains() domain.List {
+ m.mutex.RLock()
+ defer m.mutex.RUnlock()
+
+ domainSet := make(map[domain.Domain]struct{})
+ for question := range m.records {
+ domainName := strings.TrimSuffix(question.Name, ".")
+ domainSet[domain.Domain(domainName)] = struct{}{}
+ }
+
+ domains := make(domain.List, 0, len(domainSet))
+ for d := range domainSet {
+ domains = append(domains, d)
+ }
+
+ return domains
+}
+
+// UpdateFromServerDomains updates the cache with server domains from network configuration.
+// It merges new domains with existing ones, replacing entire domain types when updated.
+// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
+func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
+ newDomains := m.extractDomainsFromServerDomains(serverDomains)
+ var removedDomains domain.List
+
+ if len(newDomains) > 0 {
+ m.mutex.Lock()
+ if m.serverDomains == nil {
+ m.serverDomains = &dnsconfig.ServerDomains{}
+ }
+ updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains)
+ m.serverDomains = &updatedServerDomains
+ m.mutex.Unlock()
+
+ allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
+ currentDomains := m.GetCachedDomains()
+ removedDomains = m.removeStaleDomains(currentDomains, allDomains)
+ }
+
+ m.addNewDomains(ctx, newDomains)
+
+ return removedDomains, nil
+}
+
+// removeStaleDomains removes cached domains not present in the target domain list.
+// Management domains are preserved and never removed during server domain updates.
+func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
+ var removedDomains domain.List
+
+ for _, currentDomain := range currentDomains {
+ if m.isDomainInList(currentDomain, newDomains) {
+ continue
+ }
+
+ if m.isManagementDomain(currentDomain) {
+ continue
+ }
+
+ removedDomains = append(removedDomains, currentDomain)
+ if err := m.RemoveDomain(currentDomain); err != nil {
+ log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
+ }
+ }
+
+ return removedDomains
+}
+
+// mergeServerDomains merges new server domains with existing ones.
+// When a domain type is provided in the new domains, it completely replaces that type.
+func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains {
+ merged := existing
+
+ if incoming.Signal != "" {
+ merged.Signal = incoming.Signal
+ }
+ if len(incoming.Relay) > 0 {
+ merged.Relay = incoming.Relay
+ }
+ if incoming.Flow != "" {
+ merged.Flow = incoming.Flow
+ }
+ if len(incoming.Stuns) > 0 {
+ merged.Stuns = incoming.Stuns
+ }
+ if len(incoming.Turns) > 0 {
+ merged.Turns = incoming.Turns
+ }
+
+ return merged
+}
+
+// isDomainInList checks if domain exists in the list
+func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
+ for _, d := range list {
+ if domain.SafeString() == d.SafeString() {
+ return true
+ }
+ }
+ return false
+}
+
+// isManagementDomain checks if domain is the protected management domain
+func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
+ m.mutex.RLock()
+ defer m.mutex.RUnlock()
+
+ return m.mgmtDomain != nil && domain == *m.mgmtDomain
+}
+
+// addNewDomains resolves and caches all domains from the update
+func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
+ for _, newDomain := range newDomains {
+ if err := m.AddDomain(ctx, newDomain); err != nil {
+ log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
+ } else {
+ log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
+ }
+ }
+}
+
+func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
+ var domains domain.List
+
+ if serverDomains.Signal != "" {
+ domains = append(domains, serverDomains.Signal)
+ }
+
+ for _, relay := range serverDomains.Relay {
+ if relay != "" {
+ domains = append(domains, relay)
+ }
+ }
+
+ if serverDomains.Flow != "" {
+ domains = append(domains, serverDomains.Flow)
+ }
+
+ for _, stun := range serverDomains.Stuns {
+ if stun != "" {
+ domains = append(domains, stun)
+ }
+ }
+
+ for _, turn := range serverDomains.Turns {
+ if turn != "" {
+ domains = append(domains, turn)
+ }
+ }
+
+ return domains
+}
diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go
new file mode 100644
index 000000000..99d289871
--- /dev/null
+++ b/client/internal/dns/mgmt/mgmt_test.go
@@ -0,0 +1,416 @@
+package mgmt
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+
+ dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
+ "github.com/netbirdio/netbird/client/internal/dns/test"
+ "github.com/netbirdio/netbird/shared/management/domain"
+)
+
+func TestResolver_NewResolver(t *testing.T) {
+ resolver := NewResolver()
+
+ assert.NotNil(t, resolver)
+ assert.NotNil(t, resolver.records)
+ assert.False(t, resolver.MatchSubdomains())
+}
+
+func TestResolver_ExtractDomainFromURL(t *testing.T) {
+ tests := []struct {
+ name string
+ urlStr string
+ expectedDom string
+ expectError bool
+ }{
+ {
+ name: "HTTPS URL with port",
+ urlStr: "https://api.netbird.io:443",
+ expectedDom: "api.netbird.io",
+ expectError: false,
+ },
+ {
+ name: "HTTP URL without port",
+ urlStr: "http://signal.example.com",
+ expectedDom: "signal.example.com",
+ expectError: false,
+ },
+ {
+ name: "URL with path",
+ urlStr: "https://relay.netbird.io/status",
+ expectedDom: "relay.netbird.io",
+ expectError: false,
+ },
+ {
+ name: "Invalid URL",
+ urlStr: "not-a-valid-url",
+ expectedDom: "not-a-valid-url",
+ expectError: false,
+ },
+ {
+ name: "Empty URL",
+ urlStr: "",
+ expectedDom: "",
+ expectError: true,
+ },
+ {
+ name: "STUN URL",
+ urlStr: "stun:stun.example.com:3478",
+ expectedDom: "stun.example.com",
+ expectError: false,
+ },
+ {
+ name: "TURN URL",
+ urlStr: "turn:turn.example.com:3478",
+ expectedDom: "turn.example.com",
+ expectError: false,
+ },
+ {
+ name: "REL URL",
+ urlStr: "rel://relay.example.com:443",
+ expectedDom: "relay.example.com",
+ expectError: false,
+ },
+ {
+ name: "RELS URL",
+ urlStr: "rels://relay.example.com:443",
+ expectedDom: "relay.example.com",
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var parsedURL *url.URL
+ var err error
+
+ if tt.urlStr != "" {
+ parsedURL, err = url.Parse(tt.urlStr)
+ if err != nil && !tt.expectError {
+ t.Fatalf("Failed to parse URL: %v", err)
+ }
+ }
+
+ domain, err := extractDomainFromURL(parsedURL)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expectedDom, domain.SafeString())
+ }
+ })
+ }
+}
+
+func TestResolver_PopulateFromConfig(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ resolver := NewResolver()
+
+ // Test with IP address - should return error since IP addresses are rejected
+ mgmtURL, _ := url.Parse("https://127.0.0.1")
+
+ err := resolver.PopulateFromConfig(ctx, mgmtURL)
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
+
+ // No domains should be cached when using IP addresses
+ domains := resolver.GetCachedDomains()
+ assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
+}
+
+func TestResolver_ServeDNS(t *testing.T) {
+ resolver := NewResolver()
+ ctx := context.Background()
+
+ // Add a test domain to the cache - use example.org which is reserved for testing
+ testDomain, err := domain.FromString("example.org")
+ if err != nil {
+ t.Fatalf("Failed to create domain: %v", err)
+ }
+ err = resolver.AddDomain(ctx, testDomain)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+
+ // Test A record query for cached domain
+ t.Run("Cached domain A record", func(t *testing.T) {
+ var capturedMsg *dns.Msg
+ mockWriter := &test.MockResponseWriter{
+ WriteMsgFunc: func(m *dns.Msg) error {
+ capturedMsg = m
+ return nil
+ },
+ }
+
+ req := new(dns.Msg)
+ req.SetQuestion("example.org.", dns.TypeA)
+
+ resolver.ServeDNS(mockWriter, req)
+
+ assert.NotNil(t, capturedMsg)
+ assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
+ assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
+ })
+
+ // Test uncached domain signals to continue to next handler
+ t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
+ var capturedMsg *dns.Msg
+ mockWriter := &test.MockResponseWriter{
+ WriteMsgFunc: func(m *dns.Msg) error {
+ capturedMsg = m
+ return nil
+ },
+ }
+
+ req := new(dns.Msg)
+ req.SetQuestion("unknown.example.com.", dns.TypeA)
+
+ resolver.ServeDNS(mockWriter, req)
+
+ assert.NotNil(t, capturedMsg)
+ assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
+ // Zero flag set to true signals the handler chain to continue to next handler
+ assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
+ assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
+ })
+
+ // Test that subdomains of cached domains are NOT resolved
+ t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
+ var capturedMsg *dns.Msg
+ mockWriter := &test.MockResponseWriter{
+ WriteMsgFunc: func(m *dns.Msg) error {
+ capturedMsg = m
+ return nil
+ },
+ }
+
+ // Query for a subdomain of our cached domain
+ req := new(dns.Msg)
+ req.SetQuestion("sub.example.org.", dns.TypeA)
+
+ resolver.ServeDNS(mockWriter, req)
+
+ assert.NotNil(t, capturedMsg)
+ assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
+ assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
+ assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
+ })
+
+ // Test case-insensitive matching
+ t.Run("Case-insensitive domain matching", func(t *testing.T) {
+ var capturedMsg *dns.Msg
+ mockWriter := &test.MockResponseWriter{
+ WriteMsgFunc: func(m *dns.Msg) error {
+ capturedMsg = m
+ return nil
+ },
+ }
+
+ // Query with different casing
+ req := new(dns.Msg)
+ req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
+
+ resolver.ServeDNS(mockWriter, req)
+
+ assert.NotNil(t, capturedMsg)
+ assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
+ assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
+ })
+}
+
+func TestResolver_GetCachedDomains(t *testing.T) {
+ resolver := NewResolver()
+ ctx := context.Background()
+
+ testDomain, err := domain.FromString("example.org")
+ if err != nil {
+ t.Fatalf("Failed to create domain: %v", err)
+ }
+ err = resolver.AddDomain(ctx, testDomain)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+
+ cachedDomains := resolver.GetCachedDomains()
+
+ assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
+ assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
+ assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
+}
+
+func TestResolver_ManagementDomainProtection(t *testing.T) {
+ resolver := NewResolver()
+ ctx := context.Background()
+
+ mgmtURL, _ := url.Parse("https://example.org")
+ err := resolver.PopulateFromConfig(ctx, mgmtURL)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+
+ initialDomains := resolver.GetCachedDomains()
+ if len(initialDomains) == 0 {
+ t.Skip("Management domain failed to resolve, skipping test")
+ }
+ assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
+ assert.Equal(t, "example.org", initialDomains[0].SafeString())
+
+ serverDomains := dnsconfig.ServerDomains{
+ Signal: "google.com",
+ Relay: []domain.Domain{"cloudflare.com"},
+ }
+
+ _, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
+ if err != nil {
+ t.Logf("Server domains update failed: %v", err)
+ }
+
+ finalDomains := resolver.GetCachedDomains()
+
+ managementStillCached := false
+ for _, d := range finalDomains {
+ if d.SafeString() == "example.org" {
+ managementStillCached = true
+ break
+ }
+ }
+ assert.True(t, managementStillCached, "Management domain should never be removed")
+}
+
+// extractDomainFromURL extracts a domain from a URL - test helper function
+func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
+ if u == nil {
+ return "", fmt.Errorf("URL is nil")
+ }
+ return dnsconfig.ExtractValidDomain(u.String())
+}
+
+func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
+ resolver := NewResolver()
+ ctx := context.Background()
+
+ // Set up initial domains using resolvable domains
+ initialDomains := dnsconfig.ServerDomains{
+ Signal: "example.org",
+ Stuns: []domain.Domain{"google.com"},
+ Turns: []domain.Domain{"cloudflare.com"},
+ }
+
+ // Add initial domains
+ _, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+
+ // Verify domains were added
+ cachedDomains := resolver.GetCachedDomains()
+ assert.Len(t, cachedDomains, 3)
+
+ // Update with empty ServerDomains (simulating partial network map update)
+ emptyDomains := dnsconfig.ServerDomains{}
+ removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
+ assert.NoError(t, err)
+
+ // Verify no domains were removed
+ assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty")
+
+ // Verify all original domains are still cached
+ finalDomains := resolver.GetCachedDomains()
+ assert.Len(t, finalDomains, 3, "All original domains should still be cached")
+}
+
+func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
+ resolver := NewResolver()
+ ctx := context.Background()
+
+ // Set up initial complete domains using resolvable domains
+ initialDomains := dnsconfig.ServerDomains{
+ Signal: "example.org",
+ Stuns: []domain.Domain{"google.com"},
+ Turns: []domain.Domain{"cloudflare.com"},
+ }
+
+ // Add initial domains
+ _, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+ assert.Len(t, resolver.GetCachedDomains(), 3)
+
+ // Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
+ partialDomains := dnsconfig.ServerDomains{
+ Signal: "github.com",
+ }
+ removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+
+ // Should remove only the old signal domain
+ assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
+ assert.Equal(t, "example.org", removedDomains[0].SafeString())
+
+ finalDomains := resolver.GetCachedDomains()
+ assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains")
+
+ domainStrings := make([]string, len(finalDomains))
+ for i, d := range finalDomains {
+ domainStrings[i] = d.SafeString()
+ }
+ assert.Contains(t, domainStrings, "github.com")
+ assert.Contains(t, domainStrings, "google.com")
+ assert.Contains(t, domainStrings, "cloudflare.com")
+ assert.NotContains(t, domainStrings, "example.org")
+}
+
+func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
+ resolver := NewResolver()
+ ctx := context.Background()
+
+ // Set up initial complete domains using resolvable domains
+ initialDomains := dnsconfig.ServerDomains{
+ Signal: "example.org",
+ Stuns: []domain.Domain{"google.com"},
+ Turns: []domain.Domain{"cloudflare.com"},
+ }
+
+ // Add initial domains
+ _, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+ assert.Len(t, resolver.GetCachedDomains(), 3)
+
+ // Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
+ partialDomains := dnsconfig.ServerDomains{
+ Flow: "github.com",
+ }
+ removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
+ if err != nil {
+ t.Skipf("Skipping test due to DNS resolution failure: %v", err)
+ }
+
+ assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
+
+ finalDomains := resolver.GetCachedDomains()
+ assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
+
+ domainStrings := make([]string, len(finalDomains))
+ for i, d := range finalDomains {
+ domainStrings[i] = d.SafeString()
+ }
+ assert.Contains(t, domainStrings, "example.org")
+ assert.Contains(t, domainStrings, "google.com")
+ assert.Contains(t, domainStrings, "cloudflare.com")
+ assert.Contains(t, domainStrings, "github.com")
+}
diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go
index d160fa99a..0f89b9016 100644
--- a/client/internal/dns/mock_server.go
+++ b/client/internal/dns/mock_server.go
@@ -3,20 +3,23 @@ package dns
import (
"fmt"
"net/netip"
+ "net/url"
"github.com/miekg/dns"
+ dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
// MockServer is the mock instance of a dns server
type MockServer struct {
- InitializeFunc func() error
- StopFunc func()
- UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
- RegisterHandlerFunc func(domain.List, dns.Handler, int)
- DeregisterHandlerFunc func(domain.List, int)
+ InitializeFunc func() error
+ StopFunc func()
+ UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
+ RegisterHandlerFunc func(domain.List, dns.Handler, int)
+ DeregisterHandlerFunc func(domain.List, int)
+ UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
}
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string {
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
+
+func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
+ if m.UpdateServerConfigFunc != nil {
+ return m.UpdateServerConfigFunc(domains)
+ }
+ return nil
+}
+
+func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
+ return nil
+}
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index cbcf6a256..8cb886203 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/netip"
+ "net/url"
"runtime"
"strings"
"sync"
@@ -15,7 +16,9 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack"
+ dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local"
+ "github.com/netbirdio/netbird/client/internal/dns/mgmt"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -45,6 +48,8 @@ type Server interface {
OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string
ProbeAvailability()
+ UpdateServerConfig(domains dnsconfig.ServerDomains) error
+ PopulateManagementDomain(mgmtURL *url.URL) error
}
type nsGroupsByDomain struct {
@@ -77,6 +82,8 @@ type DefaultServer struct {
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
+ mgmtCacheResolver *mgmt.Resolver
+
// permanent related properties
permanent bool
hostsDNSHolder *hostsDNSHolder
@@ -104,18 +111,20 @@ type handlerWrapper struct {
type registeredHandlerMap map[types.HandlerID]handlerWrapper
+// DefaultServerConfig holds configuration parameters for NewDefaultServer
+type DefaultServerConfig struct {
+ WgInterface WGIface
+ CustomAddress string
+ StatusRecorder *peer.Status
+ StateManager *statemanager.Manager
+ DisableSys bool
+}
+
// NewDefaultServer returns a new dns server
-func NewDefaultServer(
- ctx context.Context,
- wgInterface WGIface,
- customAddress string,
- statusRecorder *peer.Status,
- stateManager *statemanager.Manager,
- disableSys bool,
-) (*DefaultServer, error) {
+func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) {
var addrPort *netip.AddrPort
- if customAddress != "" {
- parsedAddrPort, err := netip.ParseAddrPort(customAddress)
+ if config.CustomAddress != "" {
+ parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
if err != nil {
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
}
@@ -123,13 +132,14 @@ func NewDefaultServer(
}
var dnsService service
- if wgInterface.IsUserspaceBind() {
- dnsService = NewServiceViaMemory(wgInterface)
+ if config.WgInterface.IsUserspaceBind() {
+ dnsService = NewServiceViaMemory(config.WgInterface)
} else {
- dnsService = newServiceViaListener(wgInterface, addrPort)
+ dnsService = newServiceViaListener(config.WgInterface, addrPort)
}
- return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil
+ server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
+ return server, nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -178,20 +188,24 @@ func newDefaultServer(
) *DefaultServer {
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
+
+ mgmtCacheResolver := mgmt.NewResolver()
+
defaultServer := &DefaultServer{
- ctx: ctx,
- ctxCancel: stop,
- disableSys: disableSys,
- service: dnsService,
- handlerChain: handlerChain,
- extraDomains: make(map[domain.Domain]int),
- dnsMuxMap: make(registeredHandlerMap),
- localResolver: local.NewResolver(),
- wgInterface: wgInterface,
- statusRecorder: statusRecorder,
- stateManager: stateManager,
- hostsDNSHolder: newHostsDNSHolder(),
- hostManager: &noopHostConfigurator{},
+ ctx: ctx,
+ ctxCancel: stop,
+ disableSys: disableSys,
+ service: dnsService,
+ handlerChain: handlerChain,
+ extraDomains: make(map[domain.Domain]int),
+ dnsMuxMap: make(registeredHandlerMap),
+ localResolver: local.NewResolver(),
+ wgInterface: wgInterface,
+ statusRecorder: statusRecorder,
+ stateManager: stateManager,
+ hostsDNSHolder: newHostsDNSHolder(),
+ hostManager: &noopHostConfigurator{},
+ mgmtCacheResolver: mgmtCacheResolver,
}
// register with root zone, handler chain takes care of the routing
@@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
- log.Debugf("registering handler %s with priority %d", handler, priority)
+ log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains)
for _, domain := range domains {
if domain == "" {
@@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
- log.Debugf("deregistering handler %v with priority %d", domains, priority)
+ log.Debugf("deregistering handler with priority %d for %v", priority, domains)
for _, domain := range domains {
if domain == "" {
@@ -432,6 +446,29 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Wait()
}
+func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ if s.mgmtCacheResolver != nil {
+ removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
+ if err != nil {
+ return fmt.Errorf("update management cache resolver: %w", err)
+ }
+
+ if len(removedDomains) > 0 {
+ s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
+ }
+
+ newDomains := s.mgmtCacheResolver.GetCachedDomains()
+ if len(newDomains) > 0 {
+ s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
+ }
+ }
+
+ return nil
+}
+
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver
if update.ServiceEnable {
@@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain {
),
)
}
+
+// PopulateManagementDomain populates the DNS cache with management domain
+func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
+ if s.mgmtCacheResolver != nil {
+ return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
+ }
+ return nil
+}
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index 068f001d8..11575d500 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
- dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
+ dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
+ WgInterface: wgIface,
+ CustomAddress: "",
+ StatusRecorder: peer.NewRecorder("mgm"),
+ StateManager: nil,
+ DisableSys: false,
+ })
if err != nil {
t.Fatal(err)
}
@@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
- dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
+ dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
+ WgInterface: wgIface,
+ CustomAddress: "",
+ StatusRecorder: peer.NewRecorder("mgm"),
+ StateManager: nil,
+ DisableSys: false,
+ })
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
- dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false)
+ dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
+ WgInterface: &mocWGIface{},
+ CustomAddress: testCase.addrPort,
+ StatusRecorder: peer.NewRecorder("mgm"),
+ StateManager: nil,
+ DisableSys: false,
+ })
if err != nil {
t.Fatalf("%v", err)
}
diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go
index 89d637686..6ef0ab526 100644
--- a/client/internal/dns/service_memory.go
+++ b/client/internal/dns/service_memory.go
@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type ServiceViaMemory struct {
diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go
index 071e3617a..c19e0acb5 100644
--- a/client/internal/dns/upstream.go
+++ b/client/internal/dns/upstream.go
@@ -33,9 +33,11 @@ func SetCurrentMTU(mtu uint16) {
}
const (
- UpstreamTimeout = 15 * time.Second
+ UpstreamTimeout = 4 * time.Second
+ // ClientTimeout is the timeout for the dns.Client.
+ // Set longer than UpstreamTimeout to ensure context timeout takes precedence
+ ClientTimeout = 5 * time.Second
- failsTillDeact = int32(5)
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
)
@@ -58,9 +60,7 @@ type upstreamResolverBase struct {
upstreamServers []netip.AddrPort
domain string
disabled bool
- failsCount atomic.Int32
successCount atomic.Int32
- failsTillDeact int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
@@ -79,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
domain: domain,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
- failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder,
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
- return fmt.Sprintf("upstream %s", u.upstreamServers)
+ return fmt.Sprintf("Upstream %s", u.upstreamServers)
}
// ID returns the unique handler ID
@@ -116,58 +115,102 @@ func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
- var err error
- defer func() {
- u.checkUpstreamFails(err)
- }()
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
+
+ u.prepareRequest(r)
+
+ if u.ctx.Err() != nil {
+ logger.Tracef("%s has been stopped", u)
+ return
+ }
+
+ if u.tryUpstreamServers(w, r, logger) {
+ return
+ }
+
+ u.writeErrorResponse(w, r, logger)
+}
+
+func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
+}
- select {
- case <-u.ctx.Done():
- logger.Tracef("%s has been stopped", u)
- return
- default:
+func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
+ timeout := u.upstreamTimeout
+ if len(u.upstreamServers) > 1 {
+ maxTotal := 5 * time.Second
+ minPerUpstream := 2 * time.Second
+ scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
+ if scaledTimeout > minPerUpstream {
+ timeout = scaledTimeout
+ } else {
+ timeout = minPerUpstream
+ }
}
for _, upstream := range u.upstreamServers {
- var rm *dns.Msg
- var t time.Duration
-
- func() {
- ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
- defer cancel()
- rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
- }()
-
- if err != nil {
- if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
- logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
- continue
- }
- logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
- continue
+ if u.queryUpstream(w, r, upstream, timeout, logger) {
+ return true
}
+ }
+ return false
+}
- if rm == nil || !rm.Response {
- logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
- continue
- }
+func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
+ var rm *dns.Msg
+ var t time.Duration
+ var err error
- u.successCount.Add(1)
- logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
+ var startTime time.Time
+ func() {
+ ctx, cancel := context.WithTimeout(u.ctx, timeout)
+ defer cancel()
+ startTime = time.Now()
+ rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
+ }()
- if err = w.WriteMsg(rm); err != nil {
- logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
- }
- // count the fails only if they happen sequentially
- u.failsCount.Store(0)
+ if err != nil {
+ u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
+ return false
+ }
+
+ if rm == nil || !rm.Response {
+ logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
+ return false
+ }
+
+ return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
+}
+
+func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
+ if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
+ logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return
}
- u.failsCount.Add(1)
+
+ elapsed := time.Since(startTime)
+ timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
+ if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
+ timeoutMsg += " " + peerInfo
+ }
+ timeoutMsg += fmt.Sprintf(" - error: %v", err)
+ logger.Warnf(timeoutMsg)
+}
+
+func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
+ u.successCount.Add(1)
+ logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
+
+ if err := w.WriteMsg(rm); err != nil {
+ logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
+ }
+ return true
+}
+
+func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
@@ -177,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}
-// checkUpstreamFails counts fails and disables or enables upstream resolving
-//
-// If fails count is greater that failsTillDeact, upstream resolving
-// will be disabled for reactivatePeriod, after that time period fails counter
-// will be reset and upstream will be reactivated.
-func (u *upstreamResolverBase) checkUpstreamFails(err error) {
- u.mutex.Lock()
- defer u.mutex.Unlock()
-
- if u.failsCount.Load() < u.failsTillDeact || u.disabled {
- return
- }
-
- select {
- case <-u.ctx.Done():
- return
- default:
- }
-
- u.disable(err)
-
- if u.statusRecorder == nil {
- return
- }
-
- u.statusRecorder.PublishEvent(
- proto.SystemEvent_WARNING,
- proto.SystemEvent_DNS,
- "All upstream servers failed (fail count exceeded)",
- "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
- map[string]string{"upstreams": u.upstreamServersString()},
- // TODO add domain meta
- )
-}
-
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
@@ -224,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() {
default:
}
- // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
- if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
+ // avoid probe if upstreams could resolve at least one query
+ if u.successCount.Load() > 0 {
return
}
@@ -312,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
- u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate()
u.disabled = false
@@ -416,3 +423,80 @@ func GenerateRequestID() string {
}
return hex.EncodeToString(bytes)
}
+
+// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
+func FormatPeerStatus(peerState *peer.State) string {
+ isConnected := peerState.ConnStatus == peer.StatusConnected
+ hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
+ time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
+
+ statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
+
+ switch {
+ case !isConnected:
+ statusInfo += " DISCONNECTED"
+ case !hasRecentHandshake:
+ statusInfo += " NO_RECENT_HANDSHAKE"
+ default:
+ statusInfo += " connected"
+ }
+
+ if !peerState.LastWireguardHandshake.IsZero() {
+ timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
+ statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
+ } else {
+ statusInfo += " no_handshake"
+ }
+
+ if peerState.Relayed {
+ statusInfo += " via_relay"
+ }
+
+ if peerState.Latency > 0 {
+ statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
+ }
+
+ return statusInfo
+}
+
+// findPeerForIP finds which peer handles the given IP address
+func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
+ if statusRecorder == nil {
+ return nil
+ }
+
+ fullStatus := statusRecorder.GetFullStatus()
+ var bestMatch *peer.State
+ var bestPrefixLen int
+
+ for _, peerState := range fullStatus.Peers {
+ routes := peerState.GetRoutes()
+ for route := range routes {
+ prefix, err := netip.ParsePrefix(route)
+ if err != nil {
+ continue
+ }
+
+ if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
+ peerStateCopy := peerState
+ bestMatch = &peerStateCopy
+ bestPrefixLen = prefix.Bits()
+ }
+ }
+ }
+
+ return bestMatch
+}
+
+func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
+ if u.statusRecorder == nil {
+ return ""
+ }
+
+ peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
+ if peerInfo == nil {
+ return ""
+ }
+
+ return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
+}
diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go
index ddbf84ae4..def281f28 100644
--- a/client/internal/dns/upstream_android.go
+++ b/client/internal/dns/upstream_android.go
@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type upstreamResolver struct {
@@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
}
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
- upstreamExchangeClient := &dns.Client{}
+ upstreamExchangeClient := &dns.Client{
+ Timeout: ClientTimeout,
+ }
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
@@ -72,10 +74,11 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
}
upstreamExchangeClient := &dns.Client{
- Dialer: dialer,
+ Dialer: dialer,
+ Timeout: timeout,
}
- return upstreamExchangeClient.Exchange(r, upstream)
+ return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go
index 317588a27..434e5880b 100644
--- a/client/internal/dns/upstream_general.go
+++ b/client/internal/dns/upstream_general.go
@@ -34,7 +34,10 @@ func newUpstreamResolver(
}
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
- return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
+ client := &dns.Client{
+ Timeout: ClientTimeout,
+ }
+ return ExchangeWithFallback(ctx, client, r, upstream)
}
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go
index 96b8bbb0f..eadcdd117 100644
--- a/client/internal/dns/upstream_ios.go
+++ b/client/internal/dns/upstream_ios.go
@@ -47,7 +47,9 @@ func newUpstreamResolver(
}
func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
- client := &dns.Client{}
+ client := &dns.Client{
+ Timeout: ClientTimeout,
+ }
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
@@ -110,7 +112,8 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura
},
}
client := &dns.Client{
- Dialer: dialer,
+ Dialer: dialer,
+ Timeout: dialTimeout,
}
return client, nil
}
diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go
index 51d870e2a..e1573e75e 100644
--- a/client/internal/dns/upstream_test.go
+++ b/client/internal/dns/upstream_test.go
@@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
+ mockClient := &mockUpstreamResolver{
+ err: dns.ErrTime,
+ r: new(dns.Msg),
+ rtt: time.Millisecond,
+ }
+
resolver := &upstreamResolverBase{
- ctx: context.TODO(),
- upstreamClient: &mockUpstreamResolver{
- err: nil,
- r: new(dns.Msg),
- rtt: time.Millisecond,
- },
+ ctx: context.TODO(),
+ upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
- reactivatePeriod: reactivatePeriod,
- failsTillDeact: failsTillDeact,
+ reactivatePeriod: time.Microsecond * 100,
}
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
- resolver.failsTillDeact = 0
- resolver.reactivatePeriod = time.Microsecond * 100
-
- responseWriter := &test.MockResponseWriter{
- WriteMsgFunc: func(m *dns.Msg) error { return nil },
- }
failed := false
resolver.deactivate = func(error) {
failed = true
+ // After deactivation, make the mock client work again
+ mockClient.err = nil
}
reactivated := false
@@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true
}
- resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
+ resolver.ProbeAvailability()
if !failed {
t.Errorf("expected that resolving was deactivated")
@@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
return
}
- if resolver.failsCount.Load() != 0 {
- t.Errorf("fails count after reactivation should be 0")
- return
- }
-
if resolver.disabled {
t.Errorf("should be enabled")
}
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 10f709dfa..61d109219 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -7,6 +7,7 @@ import (
"math/rand"
"net"
"net/netip"
+ "net/url"
"os"
"runtime"
"slices"
@@ -16,8 +17,8 @@ import (
"time"
"github.com/hashicorp/go-multierror"
- "github.com/pion/ice/v3"
- "github.com/pion/stun/v2"
+ "github.com/pion/ice/v4"
+ "github.com/pion/stun/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -27,10 +28,11 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
- "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
+ dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -165,7 +167,7 @@ type Engine struct {
wgInterface WGIface
- udpMux *bind.UniversalUDPMuxDefault
+ udpMux *udpmux.UniversalUDPMuxDefault
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
@@ -196,6 +198,10 @@ type Engine struct {
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
+
+ // WireGuard interface monitor
+ wgIfaceMonitor *WGIfaceMonitor
+ wgIfaceMonitorWg sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -344,13 +350,16 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
+ // Stop WireGuard interface monitor and wait for it to exit
+ e.wgIfaceMonitorWg.Wait()
+
return nil
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
-func (e *Engine) Start() error {
+func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -406,6 +415,11 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer
+ // Populate DNS cache with NetbirdConfig and management URL for early resolution
+ if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil {
+ log.Warnf("failed to populate DNS cache: %v", err)
+ }
+
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: e.ctx,
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
@@ -459,7 +473,7 @@ func (e *Engine) Start() error {
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
- UDPMux: e.udpMux.UDPMuxDefault,
+ UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
@@ -475,6 +489,22 @@ func (e *Engine) Start() error {
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
+
+ // monitor WireGuard interface lifecycle and restart engine on changes
+ e.wgIfaceMonitor = NewWGIfaceMonitor()
+ e.wgIfaceMonitorWg.Add(1)
+
+ go func() {
+ defer e.wgIfaceMonitorWg.Done()
+
+ if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
+ log.Infof("WireGuard interface monitor: %s, restarting engine", err)
+ e.restartEngine()
+ } else if err != nil {
+ log.Warnf("WireGuard interface monitor: %s", err)
+ }
+ }()
+
return nil
}
@@ -666,6 +696,30 @@ func (e *Engine) removePeer(peerKey string) error {
return nil
}
+// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response
+func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
+ if e.dnsServer == nil {
+ return nil
+ }
+
+ // Populate management URL if provided
+ if mgmtURL != nil {
+ if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil {
+ log.Warnf("failed to populate DNS cache with management URL: %v", err)
+ }
+ }
+
+ // Populate NetbirdConfig domains if provided
+ if netbirdConfig != nil {
+ serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig)
+ if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
+ return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err)
+ }
+ }
+
+ return nil
+}
+
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -697,6 +751,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err)
}
+ if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
+ log.Warnf("Failed to update DNS server config: %v", err)
+ }
+
// todo update signal
}
@@ -867,7 +925,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHRemotePortForwarding,
)
- // err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
@@ -878,7 +935,7 @@ func (e *Engine) receiveManagementEvents() {
}
log.Debugf("stopped receiving updates from Management Service")
}()
- log.Debugf("connecting to Management Service updates stream")
+ log.Infof("connecting to Management Service updates stream")
}
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
@@ -1253,7 +1310,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
- UDPMux: e.udpMux.UDPMuxDefault,
+ UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},
@@ -1515,7 +1572,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
default:
- dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
+
+ dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
+ WgInterface: e.wgInterface,
+ CustomAddress: e.config.CustomDNSAddress,
+ StatusRecorder: e.statusRecorder,
+ StateManager: e.stateManager,
+ DisableSys: e.config.DisableDNS,
+ })
if err != nil {
return nil, err
}
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index ce805c776..e82a96abf 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -19,21 +19,18 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
+ wgdevice "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
- wgdevice "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun/netstack"
-
"github.com/netbirdio/management-integrations/integrations"
- "github.com/netbirdio/netbird/management/internals/server/config"
- "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface"
- "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -45,9 +42,12 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
+ "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -85,7 +85,7 @@ type MockWGIface struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface
- UpFunc func() (*bind.UniversalUDPMuxDefault, error)
+ UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
@@ -135,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
-func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
@@ -244,7 +244,7 @@ func TestEngine_SSH(t *testing.T) {
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
- err = engine.Start()
+ err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
@@ -440,7 +440,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
+ engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
}
}()
- err = engine.Start()
+ err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
@@ -1095,7 +1095,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
- err = engine.Start()
+ err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
@@ -1581,7 +1581,11 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
- ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
+
+ permissionsManager := permissions.NewManager(store)
+ peersManager := peers.NewManager(store, permissionsManager)
+
+ ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -1598,7 +1602,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
Return(&types.ExtraSettings{}, nil).
AnyTimes()
- permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go
index bf96153ea..690fdb7cc 100644
--- a/client/internal/iface_common.go
+++ b/client/internal/iface_common.go
@@ -9,9 +9,9 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
Name() string
Address() wgaddr.Address
ToInterface() *net.Interface
- Up() (*bind.UniversalUDPMuxDefault, error)
+ Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
diff --git a/client/internal/login.go b/client/internal/login.go
index ffabacf4a..28d45e49c 100644
--- a/client/internal/login.go
+++ b/client/internal/login.go
@@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool,
return false, err
}
- _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
+ _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
@@ -69,14 +69,18 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string,
return err
}
- serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
+ serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
+ if err != nil {
+ return err
+ }
+ } else if err != nil {
return err
}
- return err
+ return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
@@ -101,11 +105,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
return mgmClient, err
}
-func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) {
+func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
- return nil, err
+ return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
@@ -125,8 +129,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
)
- _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
- return serverKey, err
+ loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
+ return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go
index dbb4747a5..a4ffa3a25 100644
--- a/client/internal/netflow/conntrack/conntrack.go
+++ b/client/internal/netflow/conntrack/conntrack.go
@@ -14,7 +14,7 @@ import (
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
const defaultChannelSize = 100
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index a6cf3cd25..86e4596d4 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -6,12 +6,11 @@ import (
"math/rand"
"net"
"net/netip"
- "os"
"runtime"
"sync"
"time"
- "github.com/pion/ice/v3"
+ "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
- if os.Getenv("NB_FORCE_RELAY") != "true" {
+ if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go
new file mode 100644
index 000000000..32a458d00
--- /dev/null
+++ b/client/internal/peer/env.go
@@ -0,0 +1,14 @@
+package peer
+
+import (
+ "os"
+ "strings"
+)
+
+const (
+ EnvKeyNBForceRelay = "NB_FORCE_RELAY"
+)
+
+func isForceRelayed() bool {
+ return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
+}
diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go
index b9c9aa134..70850e6eb 100644
--- a/client/internal/peer/guard/ice_monitor.go
+++ b/client/internal/peer/guard/ice_monitor.go
@@ -6,7 +6,7 @@ import (
"sync"
"time"
- "github.com/pion/ice/v3"
+ "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go
index 3cbf74cfd..42eaea683 100644
--- a/client/internal/peer/handshaker.go
+++ b/client/internal/peer/handshaker.go
@@ -43,13 +43,6 @@ type OfferAnswer struct {
SessionID *ICESessionID
}
-func (oa *OfferAnswer) SessionIDString() string {
- if oa.SessionID == nil {
- return "unknown"
- }
- return oa.SessionID.String()
-}
-
type Handshaker struct {
mu sync.Mutex
log *log.Entry
@@ -57,7 +50,7 @@ type Handshaker struct {
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
- onNewOfferListeners []func(*OfferAnswer)
+ onNewOfferListeners []*OfferListener
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
@@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
}
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
- h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
+ l := NewOfferListener(offer)
+ h.onNewOfferListeners = append(h.onNewOfferListeners, l)
}
func (h *Handshaker) Listen(ctx context.Context) {
@@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue
}
for _, listener := range h.onNewOfferListeners {
- listener(&remoteOfferAnswer)
+ listener.Notify(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners {
- listener(&remoteOfferAnswer)
+ listener.Notify(&remoteOfferAnswer)
}
case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers")
diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go
new file mode 100644
index 000000000..e2d3f3f38
--- /dev/null
+++ b/client/internal/peer/handshaker_listener.go
@@ -0,0 +1,62 @@
+package peer
+
+import (
+ "sync"
+)
+
+type callbackFunc func(remoteOfferAnswer *OfferAnswer)
+
+func (oa *OfferAnswer) SessionIDString() string {
+ if oa.SessionID == nil {
+ return "unknown"
+ }
+ return oa.SessionID.String()
+}
+
+type OfferListener struct {
+ fn callbackFunc
+ running bool
+ latest *OfferAnswer
+ mu sync.Mutex
+}
+
+func NewOfferListener(fn callbackFunc) *OfferListener {
+ return &OfferListener{
+ fn: fn,
+ }
+}
+
+func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+
+ // Store the latest offer
+ o.latest = remoteOfferAnswer
+
+ // If already running, the running goroutine will pick up this latest value
+ if o.running {
+ return
+ }
+
+ // Start processing
+ o.running = true
+
+ // Process in a goroutine to avoid blocking the caller
+ go func(remoteOfferAnswer *OfferAnswer) {
+ for {
+ o.fn(remoteOfferAnswer)
+
+ o.mu.Lock()
+ if o.latest == nil {
+ // No more work to do
+ o.running = false
+ o.mu.Unlock()
+ return
+ }
+ remoteOfferAnswer = o.latest
+ // Clear the latest to mark it as being processed
+ o.latest = nil
+ o.mu.Unlock()
+ }
+ }(remoteOfferAnswer)
+}
diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go
new file mode 100644
index 000000000..8363741a5
--- /dev/null
+++ b/client/internal/peer/handshaker_listener_test.go
@@ -0,0 +1,39 @@
+package peer
+
+import (
+ "testing"
+ "time"
+)
+
+func Test_newOfferListener(t *testing.T) {
+ dummyOfferAnswer := &OfferAnswer{}
+ runChan := make(chan struct{}, 10)
+
+ longRunningFn := func(remoteOfferAnswer *OfferAnswer) {
+ time.Sleep(1 * time.Second)
+ runChan <- struct{}{}
+ }
+
+ hl := NewOfferListener(longRunningFn)
+
+ hl.Notify(dummyOfferAnswer)
+ hl.Notify(dummyOfferAnswer)
+ hl.Notify(dummyOfferAnswer)
+
+ // Wait for exactly 2 callbacks
+ for i := 0; i < 2; i++ {
+ select {
+ case <-runChan:
+ case <-time.After(3 * time.Second):
+ t.Fatal("Timeout waiting for callback")
+ }
+ }
+
+ // Verify no additional callbacks happen
+ select {
+ case <-runChan:
+ t.Fatal("Unexpected additional callback")
+ case <-time.After(100 * time.Millisecond):
+ t.Log("Correctly received exactly 2 callbacks")
+ }
+}
diff --git a/client/internal/peer/ice/StunTurn.go b/client/internal/peer/ice/StunTurn.go
index 63ee8c713..a389f5444 100644
--- a/client/internal/peer/ice/StunTurn.go
+++ b/client/internal/peer/ice/StunTurn.go
@@ -3,7 +3,7 @@ package ice
import (
"sync/atomic"
- "github.com/pion/stun/v2"
+ "github.com/pion/stun/v3"
)
type StunTurn atomic.Value
diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go
index 58c1bf634..e80c98884 100644
--- a/client/internal/peer/ice/agent.go
+++ b/client/internal/peer/ice/agent.go
@@ -4,7 +4,7 @@ import (
"sync"
"time"
- "github.com/pion/ice/v3"
+ "github.com/pion/ice/v4"
"github.com/pion/logging"
"github.com/pion/randutil"
log "github.com/sirupsen/logrus"
diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go
index dd854a605..dd5d67403 100644
--- a/client/internal/peer/ice/config.go
+++ b/client/internal/peer/ice/config.go
@@ -1,7 +1,7 @@
package ice
import (
- "github.com/pion/ice/v3"
+ "github.com/pion/ice/v4"
)
type Config struct {
diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go
index ca1d421a5..b28906625 100644
--- a/client/internal/peer/signaler.go
+++ b/client/internal/peer/signaler.go
@@ -1,7 +1,7 @@
package peer
import (
- "github.com/pion/ice/v3"
+ "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go
index 218872c15..0ed200fda 100644
--- a/client/internal/peer/wg_watcher.go
+++ b/client/internal/peer/wg_watcher.go
@@ -30,9 +30,10 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump
- ctx context.Context
- ctxCancel context.CancelFunc
- ctxLock sync.Mutex
+ ctx context.Context
+ ctxCancel context.CancelFunc
+ ctxLock sync.Mutex
+ enabledTime time.Time
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
+ w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
@@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
onDisconnectedFn()
return
}
+ if lastHandshake.IsZero() {
+ elapsed := handshake.Sub(w.enabledTime).Seconds()
+ w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
+ }
+
lastHandshake = *handshake
resetTime := time.Until(handshake.Add(checkPeriod))
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
index 4f00af829..eb886a4d3 100644
--- a/client/internal/peer/worker_ice.go
+++ b/client/internal/peer/worker_ice.go
@@ -8,12 +8,11 @@ import (
"sync"
"time"
- "github.com/pion/ice/v3"
- "github.com/pion/stun/v2"
+ "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
- "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -55,10 +54,6 @@ type WorkerICE struct {
sessionID ICESessionID
muxAgent sync.Mutex
- StunTurn []*stun.URI
-
- sentExtraSrflx bool
-
localUfrag string
localPwd string
@@ -122,7 +117,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
- // todo consider to switch to Relay connection while establishing a new ICE connection
}
var preferredCandidateTypes []ice.CandidateType
@@ -140,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.muxAgent.Unlock()
return
}
- w.sentExtraSrflx = false
w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
@@ -167,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
w.log.Errorf("error while handling remote candidate")
return
}
+
+ if shouldAddExtraCandidate(candidate) {
+ // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
+ // this is useful when network has an existing port forwarding rule for the wireguard port and this peer
+ extraSrflx, err := extraSrflxCandidate(candidate)
+ if err != nil {
+ w.log.Errorf("failed creating extra server reflexive candidate %s", err)
+ return
+ }
+
+ if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
+ w.log.Errorf("error while handling remote candidate")
+ return
+ }
+ }
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
@@ -210,14 +218,12 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
return nil, err
}
- if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
+ if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
+ w.onICESelectedCandidatePair(agent, c1, c2)
+ }); err != nil {
return nil, err
}
- if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil {
- return nil, fmt.Errorf("failed setting binding response callback: %w", err)
- }
-
return agent, nil
}
@@ -253,6 +259,11 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.closeAgent(agent, w.agentDialerCancel)
return
}
+ if pair == nil {
+ w.log.Warnf("selected candidate pair is nil, cannot proceed")
+ w.closeAgent(agent, w.agentDialerCancel)
+ return
+ }
if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one
@@ -327,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
return
}
- mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
+ mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
if !ok {
w.log.Warn("invalid udp mux conversion")
return
@@ -354,31 +365,23 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
-
- if !w.shouldSendExtraSrflxCandidate(candidate) {
- return
- }
-
- // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
- // this is useful when network has an existing port forwarding rule for the wireguard port and this peer
- extraSrflx, err := extraSrflxCandidate(candidate)
- if err != nil {
- w.log.Errorf("failed creating extra server reflexive candidate %s", err)
- return
- }
- w.sentExtraSrflx = true
-
- go func() {
- err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
- if err != nil {
- w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
- }
- }()
}
-func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
+func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
+
+ pairStat, ok := agent.GetSelectedCandidatePairStats()
+ if !ok {
+ w.log.Warnf("failed to get selected candidate pair stats")
+ return
+ }
+
+ duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
+ if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
+ w.log.Debugf("failed to update latency for peer: %s", err)
+ return
+ }
}
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
@@ -388,7 +391,10 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
- case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
+ case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
+ // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
+ // notify the conn.onICEStateDisconnected changes to update the current used priority
+
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
@@ -400,32 +406,34 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
}
}
-func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) {
- if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil {
- w.log.Debugf("failed to update latency for peer: %s", err)
- return
- }
-}
-
-func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
- if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
- return true
- }
- return false
-}
-
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
- isControlling := w.config.LocalKey > w.config.Key
- if isControlling {
- return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+ if isController(w.config) {
+ return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
+func shouldAddExtraCandidate(candidate ice.Candidate) bool {
+ if candidate.Type() != ice.CandidateTypeServerReflexive {
+ return false
+ }
+
+ if candidate.Port() == candidate.RelatedAddress().Port {
+ return false
+ }
+
+ // in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
+ // in newer version we generate locally the extra candidate
+ if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
+ return false
+ }
+ return true
+}
+
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
- return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
+ ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
@@ -433,6 +441,21 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
+ if err != nil {
+ return nil, err
+ }
+
+ for _, e := range candidate.Extensions() {
+ // overwrite the original candidate ID with the new one to avoid candidate duplication
+ if e.Key == ice.ExtensionKeyCandidateID {
+ e.Value = candidate.ID()
+ }
+ if err := ec.AddExtension(e); err != nil {
+ return nil, err
+ }
+ }
+
+ return ec, nil
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go
index 6e1f83a9a..fa208716f 100644
--- a/client/internal/relay/relay.go
+++ b/client/internal/relay/relay.go
@@ -7,12 +7,12 @@ import (
"sync"
"time"
- "github.com/pion/stun/v2"
+ "github.com/pion/stun/v3"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
// ProbeResult holds the info about the result of a relay probe request
diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go
index ba27df654..9069cdcc5 100644
--- a/client/internal/routemanager/dnsinterceptor/handler.go
+++ b/client/internal/routemanager/dnsinterceptor/handler.go
@@ -2,11 +2,13 @@ package dnsinterceptor
import (
"context"
+ "errors"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
+ "time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
@@ -26,6 +28,8 @@ import (
"github.com/netbirdio/netbird/route"
)
+const dnsTimeout = 8 * time.Second
+
type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
- client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
+ client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
- reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
+ ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
+ defer cancel()
+
+ startTime := time.Now()
+ reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
- logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
+ if errors.Is(err, context.DeadlineExceeded) {
+ elapsed := time.Since(startTime)
+ peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
+ logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
+ elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
+ } else {
+ logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
+ }
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
}
return
}
+
+func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
+ if d.statusRecorder == nil {
+ return ""
+ }
+
+ peerState, err := d.statusRecorder.GetPeer(peerKey)
+ if err != nil {
+ return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
+ }
+
+ return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
+}
diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go
index a6775c45a..04513bbe4 100644
--- a/client/internal/routemanager/manager.go
+++ b/client/internal/routemanager/manager.go
@@ -36,9 +36,9 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
- nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
+ if runtime.GOOS == "windows" && config.WGInterface != nil {
+ nbnet.SetVPNInterfaceName(config.WGInterface.Name())
+ }
+
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
return nil
}
- if err := m.sysOps.CleanupRouting(nil); err != nil {
+ if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
ips := resolveURLsToIPs(initialAddresses)
- if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
+ if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
return fmt.Errorf("setup routing: %w", err)
}
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
- if err := m.sysOps.CleanupRouting(stateManager); err != nil {
+ if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
+
+ if runtime.GOOS == "windows" {
+ nbnet.SetVPNInterfaceName("")
+ }
}
m.mux.Lock()
diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go
index a375ce832..7cb8dae93 100644
--- a/client/internal/routemanager/systemops/systemops_android.go
+++ b/client/internal/routemanager/systemops/systemops_android.go
@@ -12,11 +12,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
-func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
+func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
return nil
}
-func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
+func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
return nil
}
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index 128afa2a5..26a548634 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -3,7 +3,6 @@
package systemops
import (
- "context"
"errors"
"fmt"
"net"
@@ -22,7 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
- nbnet "github.com/netbirdio/netbird/util/net"
+ "github.com/netbirdio/netbird/client/net/hooks"
)
const localSubnetsCacheTTL = 15 * time.Minute
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil
}
- // TODO: Remove hooks selectively
- nbnet.RemoveDialerHooks()
- nbnet.RemoveListenerHooks()
+ hooks.RemoveWriteHooks()
+ hooks.RemoveCloseHooks()
+ hooks.RemoveAddressRemoveHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
- beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
- prefix, err := util.GetPrefixFromIP(ip)
- if err != nil {
- return fmt.Errorf("convert ip to prefix: %w", err)
- }
-
+ beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil
}
- afterHook := func(connID nbnet.ConnectionID) error {
+ afterHook := func(connID hooks.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
var merr *multierror.Error
for _, ip := range initAddresses {
- if err := beforeHook("init", ip); err != nil {
- merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
+ prefix, err := util.GetPrefixFromIP(ip)
+ if err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
+ continue
+ }
+ if err := beforeHook("init", prefix); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
}
}
- nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
- if ctx.Err() != nil {
- return ctx.Err()
- }
+ hooks.AddWriteHook(beforeHook)
+ hooks.AddCloseHook(afterHook)
- var merr *multierror.Error
- for _, ip := range resolvedIPs {
- merr = multierror.Append(merr, beforeHook(connID, ip.IP))
- }
- return nberrors.FormatErrorOrNil(merr)
- })
-
- nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
- return afterHook(connID)
- })
-
- nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
- return beforeHook(connID, ip.IP)
- })
-
- nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
- return afterHook(connID)
- })
-
- nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
+ hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go
index c1c1182bc..32ea38a7a 100644
--- a/client/internal/routemanager/systemops/systemops_generic_test.go
+++ b/client/internal/routemanager/systemops/systemops_generic_test.go
@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type dialer interface {
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
- err := r.SetupRouting(nil, nil)
+ advancedRouting := nbnet.AdvancedRouting()
+ err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
t.Cleanup(func() {
- assert.NoError(t, r.CleanupRouting(nil))
+ assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
intf, err := net.InterfaceByName(wgInterface.Name())
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
- err := r.SetupRouting(nil, nil)
+ advancedRouting := nbnet.AdvancedRouting()
+ err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
t.Cleanup(func() {
- assert.NoError(t, r.CleanupRouting(nil))
+ assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
})
r := NewSysOps(wgInterface, nil)
- err := r.SetupRouting(nil, nil)
+ advancedRouting := nbnet.AdvancedRouting()
+ err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
- assert.NoError(t, r.CleanupRouting(nil))
+ assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
index, err := net.InterfaceByName(wgInterface.Name())
diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go
index 10356eae0..99a363371 100644
--- a/client/internal/routemanager/systemops/systemops_ios.go
+++ b/client/internal/routemanager/systemops/systemops_ios.go
@@ -12,14 +12,14 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
-func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
+func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil
}
-func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
+func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
r.mu.Lock()
defer r.mu.Unlock()
diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go
index c0cef94ba..bd10f131f 100644
--- a/client/internal/routemanager/systemops/systemops_linux.go
+++ b/client/internal/routemanager/systemops/systemops_linux.go
@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
// IPRule contains IP rule information for debugging
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
-func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
- if !nbnet.AdvancedRouting() {
+func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
+ if !advancedRouting {
log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager)
}
defer func() {
if err != nil {
- if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
+ if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
-func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
- if !nbnet.AdvancedRouting() {
+func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
+ if !advancedRouting {
return r.cleanupRefCounter(stateManager)
}
diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go
index f165f7779..d43c2d5bf 100644
--- a/client/internal/routemanager/systemops/systemops_unix.go
+++ b/client/internal/routemanager/systemops/systemops_unix.go
@@ -20,11 +20,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
-func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
+func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager)
}
-func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
+func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
return r.cleanupRefCounter(stateManager)
}
diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go
index ad37f611f..959c697e4 100644
--- a/client/internal/routemanager/systemops/systemops_unix_test.go
+++ b/client/internal/routemanager/systemops/systemops_unix_test.go
@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type PacketExpectation struct {
diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go
index 4f836897b..7bce6af80 100644
--- a/client/internal/routemanager/systemops/systemops_windows.go
+++ b/client/internal/routemanager/systemops/systemops_windows.go
@@ -8,6 +8,7 @@ import (
"net/netip"
"os"
"runtime/debug"
+ "sort"
"strconv"
"sync"
"syscall"
@@ -19,9 +20,16 @@ import (
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
-const InfiniteLifetime = 0xffffffff
+func init() {
+ nbnet.GetBestInterfaceFunc = GetBestInterface
+}
+
+const (
+ InfiniteLifetime = 0xffffffff
+)
type RouteUpdateType int
@@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct {
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
}
+// candidateRoute represents a potential route for selection during route lookup
+type candidateRoute struct {
+ interfaceIndex uint32
+ prefixLength uint8
+ routeMetric uint32
+ interfaceMetric int
+}
+
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
type IP_ADDRESS_PREFIX struct {
Prefix SOCKADDR_INET
@@ -177,11 +193,20 @@ const (
RouteDeleted
)
-func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
+func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
+ if advancedRouting {
+ return nil
+ }
+
+ log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
-func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
+func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
+ if advancedRouting {
+ return nil
+ }
+
return r.cleanupRefCounter(stateManager)
}
@@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
if table != nil {
- ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
- if ret != 0 {
- log.Warnf("FreeMibTable failed with return code: %d", ret)
- }
+ _, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
}
}
@@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
entryPtr := basePtr + uintptr(i)*entrySize
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
- detailed := buildWindowsDetailedRoute(entry)
- if detailed != nil {
+ if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
detailedRoutes = append(detailedRoutes, *detailed)
}
}
@@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
return ip
}
+// parseCandidatesFromTable extracts all matching candidate routes from the routing table
+func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
+ var candidates []candidateRoute
+ entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
+ basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
+
+ for i := uint32(0); i < table.NumEntries; i++ {
+ entryPtr := basePtr + uintptr(i)*entrySize
+ entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
+
+ if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
+ candidates = append(candidates, *candidate)
+ }
+ }
+
+ return candidates
+}
+
+// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
+// Returns nil if the route doesn't match the destination or should be skipped
+func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
+ if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
+ return nil
+ }
+
+ destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
+ if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
+ return nil
+ }
+
+ interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
+
+ return &candidateRoute{
+ interfaceIndex: entry.InterfaceIndex,
+ prefixLength: entry.DestinationPrefix.PrefixLength,
+ routeMetric: entry.Metric,
+ interfaceMetric: interfaceMetric,
+ }
+}
+
// getInterfaceMetric retrieves the interface metric for a given interface and address family
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
if interfaceIndex == 0 {
@@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
return int(ipInterfaceRow.Metric)
}
+// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
+func sortRouteCandidates(candidates []candidateRoute) {
+ sort.Slice(candidates, func(i, j int) bool {
+ if candidates[i].prefixLength != candidates[j].prefixLength {
+ return candidates[i].prefixLength > candidates[j].prefixLength
+ }
+ if candidates[i].routeMetric != candidates[j].routeMetric {
+ return candidates[i].routeMetric < candidates[j].routeMetric
+ }
+ return candidates[i].interfaceMetric < candidates[j].interfaceMetric
+ })
+}
+
+// GetBestInterface finds the best interface for reaching a destination,
+// excluding the VPN interface to avoid routing loops.
+//
+// Route selection priority:
+// 1. Longest prefix match (most specific route)
+// 2. Lowest route metric
+// 3. Lowest interface metric
+func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
+ var skipInterfaceIndex int
+ if vpnIntf != "" {
+ if iface, err := net.InterfaceByName(vpnIntf); err == nil {
+ skipInterfaceIndex = iface.Index
+ } else {
+ // not critical, if we cannot get ahold of the interface then we won't need to skip it
+ log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
+ }
+ }
+
+ table, err := getWindowsRoutingTable()
+ if err != nil {
+ return nil, fmt.Errorf("get routing table: %w", err)
+ }
+ defer freeWindowsRoutingTable(table)
+
+ candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
+
+ if len(candidates) == 0 {
+ return nil, fmt.Errorf("no route to %s", dest)
+ }
+
+ // Sort routes: prefix length -> route metric -> interface metric
+ sortRouteCandidates(candidates)
+
+ for _, candidate := range candidates {
+ iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
+ if err != nil {
+ log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
+ continue
+ }
+
+ if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
+ continue
+ }
+
+ if iface.Flags&net.FlagUp == 0 {
+ log.Debugf("interface %s is down, trying next route", iface.Name)
+ continue
+ }
+
+ log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
+ dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
+ return iface, nil
+ }
+
+ return nil, fmt.Errorf("no usable interface found for %s", dest)
+}
+
// formatRouteAge formats the route age in seconds to a human-readable string
func formatRouteAge(ageSeconds uint32) string {
if ageSeconds == 0 {
diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go
index 523bd0b0d..3561adec4 100644
--- a/client/internal/routemanager/systemops/systemops_windows_test.go
+++ b/client/internal/routemanager/systemops/systemops_windows_test.go
@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
var (
diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go
index ac5a48e37..57ea32f69 100644
--- a/client/internal/routemanager/util/ip.go
+++ b/client/internal/routemanager/util/ip.go
@@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
+
addr = addr.Unmap()
-
- var prefixLength int
- switch {
- case addr.Is4():
- prefixLength = 32
- case addr.Is6():
- prefixLength = 128
- default:
- return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
- }
-
- prefix := netip.PrefixFrom(addr, prefixLength)
+ prefix := netip.PrefixFrom(addr, addr.BitLen())
return prefix, nil
}
diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go
index e80adb42b..8961eaa69 100644
--- a/client/internal/stdnet/dialer.go
+++ b/client/internal/stdnet/dialer.go
@@ -5,7 +5,7 @@ import (
"github.com/pion/transport/v3"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
// Dial connects to the address on the named network.
diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go
index 9ce0a5556..d3be1896f 100644
--- a/client/internal/stdnet/listener.go
+++ b/client/internal/stdnet/listener.go
@@ -6,7 +6,7 @@ import (
"github.com/pion/transport/v3"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
// ListenPacket listens for incoming packets on the given network and address.
diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go
index 171cc42cb..4b031c05c 100644
--- a/client/internal/stdnet/stdnet.go
+++ b/client/internal/stdnet/stdnet.go
@@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
if netstack.IsEnabled() {
n.iFaceDiscover = pionDiscover{}
} else {
- newMobileIFaceDiscover(iFaceDiscover)
+ n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover)
}
return n, n.UpdateInterfaces()
}
diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go
new file mode 100644
index 000000000..78d70c15b
--- /dev/null
+++ b/client/internal/wg_iface_monitor.go
@@ -0,0 +1,98 @@
+package internal
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "runtime"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
+// if the interface is deleted externally while the engine is running.
+type WGIfaceMonitor struct {
+ done chan struct{}
+}
+
+// NewWGIfaceMonitor creates a new WGIfaceMonitor instance.
+func NewWGIfaceMonitor() *WGIfaceMonitor {
+ return &WGIfaceMonitor{
+ done: make(chan struct{}),
+ }
+}
+
+// Start begins monitoring the WireGuard interface.
+// It relies on the provided context cancellation to stop.
+func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
+ defer close(m.done)
+
+ // Skip on mobile platforms as they handle interface lifecycle differently
+ if runtime.GOOS == "android" || runtime.GOOS == "ios" {
+ log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS)
+ return false, errors.New("not supported on mobile platforms")
+ }
+
+ if ifaceName == "" {
+ log.Debugf("Interface monitor: empty interface name, skipping monitor")
+ return false, errors.New("empty interface name")
+ }
+
+ // Get initial interface index to track the specific interface instance
+ expectedIndex, err := getInterfaceIndex(ifaceName)
+ if err != nil {
+ log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName)
+ return false, fmt.Errorf("interface %s not found: %w", ifaceName, err)
+ }
+
+ log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
+
+ ticker := time.NewTicker(2 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ log.Infof("Interface monitor: stopped for %s", ifaceName)
+ return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
+ case <-ticker.C:
+ currentIndex, err := getInterfaceIndex(ifaceName)
+ if err != nil {
+ // Interface was deleted
+ log.Infof("Interface monitor: %s deleted", ifaceName)
+ return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
+ }
+
+ // Check if interface index changed (interface was recreated)
+ if currentIndex != expectedIndex {
+ log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
+ ifaceName, expectedIndex, currentIndex)
+ return true, nil
+ }
+ }
+ }
+
+}
+
+// getInterfaceIndex returns the index of a network interface by name.
+// Returns an error if the interface is not found.
+func getInterfaceIndex(name string) (int, error) {
+ if name == "" {
+ return 0, fmt.Errorf("empty interface name")
+ }
+ ifi, err := net.InterfaceByName(name)
+ if err != nil {
+ // Check if it's specifically a "not found" error
+ if errors.Is(err, &net.OpError{}) {
+ // On some systems, this might be a "not found" error
+ return 0, fmt.Errorf("interface not found: %w", err)
+ }
+ return 0, fmt.Errorf("failed to lookup interface: %w", err)
+ }
+ if ifi == nil {
+ return 0, fmt.Errorf("interface not found")
+ }
+ return ifi.Index, nil
+}
diff --git a/client/net/conn.go b/client/net/conn.go
new file mode 100644
index 000000000..918e7f628
--- /dev/null
+++ b/client/net/conn.go
@@ -0,0 +1,49 @@
+//go:build !ios
+
+package net
+
+import (
+ "io"
+ "net"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/net/hooks"
+)
+
+// Conn wraps a net.Conn to override the Close method
+type Conn struct {
+ net.Conn
+ ID hooks.ConnectionID
+}
+
+// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
+// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
+func (c *Conn) Close() error {
+ return closeConn(c.ID, c.Conn)
+}
+
+// TCPConn wraps net.TCPConn to override its Close method to include hook functionality.
+type TCPConn struct {
+ *net.TCPConn
+ ID hooks.ConnectionID
+}
+
+// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
+func (c *TCPConn) Close() error {
+ return closeConn(c.ID, c.TCPConn)
+}
+
+// closeConn is a helper function to close connections and execute close hooks.
+func closeConn(id hooks.ConnectionID, conn io.Closer) error {
+ err := conn.Close()
+
+ closeHooks := hooks.GetCloseHooks()
+ for _, hook := range closeHooks {
+ if err := hook(id); err != nil {
+ log.Errorf("Error executing close hook: %v", err)
+ }
+ }
+
+ return err
+}
diff --git a/client/net/dial.go b/client/net/dial.go
new file mode 100644
index 000000000..041a00e5d
--- /dev/null
+++ b/client/net/dial.go
@@ -0,0 +1,82 @@
+//go:build !ios
+
+package net
+
+import (
+ "fmt"
+ "net"
+ "sync"
+
+ "github.com/pion/transport/v3"
+ log "github.com/sirupsen/logrus"
+)
+
+func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
+ if CustomRoutingDisabled() {
+ return net.DialUDP(network, laddr, raddr)
+ }
+
+ dialer := NewDialer()
+ dialer.LocalAddr = laddr
+
+ conn, err := dialer.Dial(network, raddr.String())
+ if err != nil {
+ return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
+ }
+
+ switch c := conn.(type) {
+ case *net.UDPConn:
+ // Advanced routing: plain connection
+ return c, nil
+ case *Conn:
+ // Legacy routing: wrapped connection preserves close hooks
+ udpConn, ok := c.Conn.(*net.UDPConn)
+ if !ok {
+ if err := conn.Close(); err != nil {
+ log.Errorf("Failed to close connection: %v", err)
+ }
+ return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn)
+ }
+ return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
+ }
+
+ if err := conn.Close(); err != nil {
+ log.Errorf("failed to close connection: %v", err)
+ }
+ return nil, fmt.Errorf("unexpected connection type: %T", conn)
+}
+
+func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
+ if CustomRoutingDisabled() {
+ return net.DialTCP(network, laddr, raddr)
+ }
+
+ dialer := NewDialer()
+ dialer.LocalAddr = laddr
+
+ conn, err := dialer.Dial(network, raddr.String())
+ if err != nil {
+ return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
+ }
+
+ switch c := conn.(type) {
+ case *net.TCPConn:
+ // Advanced routing: plain connection
+ return c, nil
+ case *Conn:
+ // Legacy routing: wrapped connection preserves close hooks
+ tcpConn, ok := c.Conn.(*net.TCPConn)
+ if !ok {
+ if err := conn.Close(); err != nil {
+ log.Errorf("Failed to close connection: %v", err)
+ }
+ return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn)
+ }
+ return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
+ }
+
+ if err := conn.Close(); err != nil {
+ log.Errorf("failed to close connection: %v", err)
+ }
+ return nil, fmt.Errorf("unexpected connection type: %T", conn)
+}
diff --git a/util/net/dial_ios.go b/client/net/dial_ios.go
similarity index 100%
rename from util/net/dial_ios.go
rename to client/net/dial_ios.go
diff --git a/util/net/dialer.go b/client/net/dialer.go
similarity index 99%
rename from util/net/dialer.go
rename to client/net/dialer.go
index 0786c667e..29bec05a7 100644
--- a/util/net/dialer.go
+++ b/client/net/dialer.go
@@ -16,6 +16,5 @@ func NewDialer() *Dialer {
Dialer: &net.Dialer{},
}
dialer.init()
-
return dialer
}
diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go
new file mode 100644
index 000000000..2e1eb53d8
--- /dev/null
+++ b/client/net/dialer_dial.go
@@ -0,0 +1,87 @@
+//go:build !ios
+
+package net
+
+import (
+ "context"
+ "fmt"
+ "net"
+
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+
+ nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/internal/routemanager/util"
+ "github.com/netbirdio/netbird/client/net/hooks"
+)
+
+// DialContext wraps the net.Dialer's DialContext method to use the custom connection
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ log.Debugf("Dialing %s %s", network, address)
+
+ if CustomRoutingDisabled() || AdvancedRouting() {
+ return d.Dialer.DialContext(ctx, network, address)
+ }
+
+ connID := hooks.GenerateConnID()
+ if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil {
+ log.Errorf("Failed to call dialer hooks: %v", err)
+ }
+
+ conn, err := d.Dialer.DialContext(ctx, network, address)
+ if err != nil {
+ return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
+ }
+
+ // Wrap the connection in Conn to handle Close with hooks
+ return &Conn{Conn: conn, ID: connID}, nil
+}
+
+// Dial wraps the net.Dialer's Dial method to use the custom connection
+func (d *Dialer) Dial(network, address string) (net.Conn, error) {
+ return d.DialContext(context.Background(), network, address)
+}
+
+func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error {
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
+
+ writeHooks := hooks.GetWriteHooks()
+ if len(writeHooks) == 0 {
+ return nil
+ }
+
+ host, _, err := net.SplitHostPort(address)
+ if err != nil {
+ return fmt.Errorf("split host and port: %w", err)
+ }
+
+ resolver := customResolver
+ if resolver == nil {
+ resolver = net.DefaultResolver
+ }
+
+ ips, err := resolver.LookupIPAddr(ctx, host)
+ if err != nil {
+ return fmt.Errorf("failed to resolve address %s: %w", address, err)
+ }
+
+ log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
+
+ var merr *multierror.Error
+ for _, ip := range ips {
+ prefix, err := util.GetPrefixFromIP(ip.IP)
+ if err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err))
+ continue
+ }
+ for _, hook := range writeHooks {
+ if err := hook(connID, prefix); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err))
+ }
+ }
+ }
+
+ return nberrors.FormatErrorOrNil(merr)
+}
diff --git a/util/net/dialer_init_android.go b/client/net/dialer_init_android.go
similarity index 100%
rename from util/net/dialer_init_android.go
rename to client/net/dialer_init_android.go
diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go
new file mode 100644
index 000000000..18ebc6ad1
--- /dev/null
+++ b/client/net/dialer_init_generic.go
@@ -0,0 +1,7 @@
+//go:build !linux && !windows
+
+package net
+
+func (d *Dialer) init() {
+ // implemented on Linux, Android, and Windows only
+}
diff --git a/util/net/dialer_init_linux.go b/client/net/dialer_init_linux.go
similarity index 100%
rename from util/net/dialer_init_linux.go
rename to client/net/dialer_init_linux.go
diff --git a/client/net/dialer_init_windows.go b/client/net/dialer_init_windows.go
new file mode 100644
index 000000000..6eefe5b1e
--- /dev/null
+++ b/client/net/dialer_init_windows.go
@@ -0,0 +1,5 @@
+package net
+
+func (d *Dialer) init() {
+ d.Dialer.Control = applyUnicastIFToSocket
+}
diff --git a/util/net/env.go b/client/net/env.go
similarity index 94%
rename from util/net/env.go
rename to client/net/env.go
index 32425665d..8f326ca88 100644
--- a/util/net/env.go
+++ b/client/net/env.go
@@ -11,6 +11,7 @@ import (
const (
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
+ envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
)
// CustomRoutingDisabled returns true if custom routing is disabled.
diff --git a/client/net/env_android.go b/client/net/env_android.go
new file mode 100644
index 000000000..9d89951a1
--- /dev/null
+++ b/client/net/env_android.go
@@ -0,0 +1,24 @@
+//go:build android
+
+package net
+
+// Init initializes the network environment for Android
+func Init() {
+ // No initialization needed on Android
+}
+
+// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
+// Always returns true on Android since we cannot handle routes dynamically.
+func AdvancedRouting() bool {
+ return true
+}
+
+// SetVPNInterfaceName is a no-op on Android
+func SetVPNInterfaceName(name string) {
+ // No-op on Android - not needed for Android VPN service
+}
+
+// GetVPNInterfaceName returns empty string on Android
+func GetVPNInterfaceName() string {
+ return ""
+}
diff --git a/client/net/env_generic.go b/client/net/env_generic.go
new file mode 100644
index 000000000..f467930c3
--- /dev/null
+++ b/client/net/env_generic.go
@@ -0,0 +1,23 @@
+//go:build !linux && !windows && !android
+
+package net
+
+// Init initializes the network environment (no-op on non-Linux/Windows platforms)
+func Init() {
+ // No-op on non-Linux/Windows platforms
+}
+
+// AdvancedRouting returns false on non-Linux/Windows platforms
+func AdvancedRouting() bool {
+ return false
+}
+
+// SetVPNInterfaceName is a no-op on non-Windows platforms
+func SetVPNInterfaceName(name string) {
+ // No-op on non-Windows platforms
+}
+
+// GetVPNInterfaceName returns empty string on non-Windows platforms
+func GetVPNInterfaceName() string {
+ return ""
+}
diff --git a/util/net/env_linux.go b/client/net/env_linux.go
similarity index 86%
rename from util/net/env_linux.go
rename to client/net/env_linux.go
index 3159f6462..82d9a74a8 100644
--- a/util/net/env_linux.go
+++ b/client/net/env_linux.go
@@ -17,8 +17,7 @@ import (
const (
// these have the same effect, skip socket env supported for backward compatibility
- envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
- envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
+ envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
)
var advancedRoutingSupported bool
@@ -27,6 +26,7 @@ func Init() {
advancedRoutingSupported = checkAdvancedRoutingSupport()
}
+// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
func AdvancedRouting() bool {
return advancedRoutingSupported
}
@@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool {
}
func CheckFwmarkSupport() bool {
- // temporarily enable advanced routing to check fwmarks are supported
+ // temporarily enable advanced routing to check if fwmarks are supported
old := advancedRoutingSupported
advancedRoutingSupported = true
defer func() {
@@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool {
}
return true
}
+
+// SetVPNInterfaceName is a no-op on Linux
+func SetVPNInterfaceName(name string) {
+ // No-op on Linux - not needed for fwmark-based routing
+}
+
+// GetVPNInterfaceName returns empty string on Linux
+func GetVPNInterfaceName() string {
+ return ""
+}
diff --git a/client/net/env_windows.go b/client/net/env_windows.go
new file mode 100644
index 000000000..7e8868ba5
--- /dev/null
+++ b/client/net/env_windows.go
@@ -0,0 +1,67 @@
+//go:build windows
+
+package net
+
+import (
+ "os"
+ "strconv"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/iface/netstack"
+)
+
+var (
+ vpnInterfaceName string
+ vpnInitMutex sync.RWMutex
+
+ advancedRoutingSupported bool
+)
+
+func Init() {
+ advancedRoutingSupported = checkAdvancedRoutingSupport()
+}
+
+func checkAdvancedRoutingSupport() bool {
+ var err error
+ var legacyRouting bool
+ if val := os.Getenv(envUseLegacyRouting); val != "" {
+ legacyRouting, err = strconv.ParseBool(val)
+ if err != nil {
+ log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
+ }
+ }
+
+ if legacyRouting || netstack.IsEnabled() {
+ log.Info("advanced routing has been requested to be disabled")
+ return false
+ }
+
+ log.Info("system supports advanced routing")
+
+ return true
+}
+
+// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
+func AdvancedRouting() bool {
+ return advancedRoutingSupported
+}
+
+// GetVPNInterfaceName returns the stored VPN interface name
+func GetVPNInterfaceName() string {
+ vpnInitMutex.RLock()
+ defer vpnInitMutex.RUnlock()
+ return vpnInterfaceName
+}
+
+// SetVPNInterfaceName sets the VPN interface name for lazy initialization
+func SetVPNInterfaceName(name string) {
+ vpnInitMutex.Lock()
+ defer vpnInitMutex.Unlock()
+ vpnInterfaceName = name
+
+ if name != "" {
+ log.Infof("VPN interface name set to %s for route exclusion", name)
+ }
+}
diff --git a/client/net/hooks/hooks.go b/client/net/hooks/hooks.go
new file mode 100644
index 000000000..93d8e18ef
--- /dev/null
+++ b/client/net/hooks/hooks.go
@@ -0,0 +1,93 @@
+package hooks
+
+import (
+ "net/netip"
+ "slices"
+ "sync"
+
+ "github.com/google/uuid"
+)
+
+// ConnectionID provides a globally unique identifier for network connections.
+// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
+type ConnectionID string
+
+// GenerateConnID generates a unique identifier for each connection.
+func GenerateConnID() ConnectionID {
+ return ConnectionID(uuid.NewString())
+}
+
+type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error
+type CloseHookFunc func(connID ConnectionID) error
+type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
+
+var (
+ hooksMutex sync.RWMutex
+
+ writeHooks []WriteHookFunc
+ closeHooks []CloseHookFunc
+ addressRemoveHooks []AddressRemoveHookFunc
+)
+
+// AddWriteHook allows adding a new hook to be executed before writing/dialing.
+func AddWriteHook(hook WriteHookFunc) {
+ hooksMutex.Lock()
+ defer hooksMutex.Unlock()
+ writeHooks = append(writeHooks, hook)
+}
+
+// AddCloseHook allows adding a new hook to be executed on connection close.
+func AddCloseHook(hook CloseHookFunc) {
+ hooksMutex.Lock()
+ defer hooksMutex.Unlock()
+ closeHooks = append(closeHooks, hook)
+}
+
+// RemoveWriteHooks removes all write hooks.
+func RemoveWriteHooks() {
+ hooksMutex.Lock()
+ defer hooksMutex.Unlock()
+ writeHooks = nil
+}
+
+// RemoveCloseHooks removes all close hooks.
+func RemoveCloseHooks() {
+ hooksMutex.Lock()
+ defer hooksMutex.Unlock()
+ closeHooks = nil
+}
+
+// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed.
+func AddAddressRemoveHook(hook AddressRemoveHookFunc) {
+ hooksMutex.Lock()
+ defer hooksMutex.Unlock()
+ addressRemoveHooks = append(addressRemoveHooks, hook)
+}
+
+// RemoveAddressRemoveHooks removes all listener address hooks.
+func RemoveAddressRemoveHooks() {
+ hooksMutex.Lock()
+ defer hooksMutex.Unlock()
+ addressRemoveHooks = nil
+}
+
+// GetWriteHooks returns a copy of the current write hooks.
+func GetWriteHooks() []WriteHookFunc {
+ hooksMutex.RLock()
+ defer hooksMutex.RUnlock()
+ return slices.Clone(writeHooks)
+}
+
+// GetCloseHooks returns a copy of the current close hooks.
+func GetCloseHooks() []CloseHookFunc {
+ hooksMutex.RLock()
+ defer hooksMutex.RUnlock()
+ return slices.Clone(closeHooks)
+}
+
+// GetAddressRemoveHooks returns a copy of the current listener address remove hooks.
+func GetAddressRemoveHooks() []AddressRemoveHookFunc {
+ hooksMutex.RLock()
+ defer hooksMutex.RUnlock()
+ return slices.Clone(addressRemoveHooks)
+}
diff --git a/client/net/listen.go b/client/net/listen.go
new file mode 100644
index 000000000..da7262806
--- /dev/null
+++ b/client/net/listen.go
@@ -0,0 +1,47 @@
+//go:build !ios
+
+package net
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+
+ "github.com/pion/transport/v3"
+ log "github.com/sirupsen/logrus"
+)
+
+// ListenUDP listens on the network address and returns a transport.UDPConn
+// which includes support for write and close hooks.
+func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
+ if CustomRoutingDisabled() {
+ return net.ListenUDP(network, laddr)
+ }
+
+ conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
+ if err != nil {
+ return nil, fmt.Errorf("listen UDP: %w", err)
+ }
+
+ switch c := conn.(type) {
+ case *net.UDPConn:
+ // Advanced routing: plain connection
+ return c, nil
+ case *PacketConn:
+ // Legacy routing: wrapped connection for hooks
+ udpConn, ok := c.PacketConn.(*net.UDPConn)
+ if !ok {
+ if err := c.Close(); err != nil {
+ log.Errorf("Failed to close connection: %v", err)
+ }
+ return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn)
+ }
+ return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
+ }
+
+ if err := conn.Close(); err != nil {
+ log.Errorf("failed to close connection: %v", err)
+ }
+ return nil, fmt.Errorf("unexpected connection type: %T", conn)
+}
diff --git a/util/net/listen_ios.go b/client/net/listen_ios.go
similarity index 100%
rename from util/net/listen_ios.go
rename to client/net/listen_ios.go
diff --git a/util/net/listener.go b/client/net/listener.go
similarity index 81%
rename from util/net/listener.go
rename to client/net/listener.go
index f4d769f58..4c2f53c05 100644
--- a/util/net/listener.go
+++ b/client/net/listener.go
@@ -7,14 +7,12 @@ import (
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
type ListenerConfig struct {
- *net.ListenConfig
+ net.ListenConfig
}
// NewListener creates a new ListenerConfig instance.
func NewListener() *ListenerConfig {
- listener := &ListenerConfig{
- ListenConfig: &net.ListenConfig{},
- }
+ listener := &ListenerConfig{}
listener.init()
return listener
diff --git a/util/net/listener_init_android.go b/client/net/listener_init_android.go
similarity index 100%
rename from util/net/listener_init_android.go
rename to client/net/listener_init_android.go
diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go
new file mode 100644
index 000000000..4f8f17ab2
--- /dev/null
+++ b/client/net/listener_init_generic.go
@@ -0,0 +1,7 @@
+//go:build !linux && !windows
+
+package net
+
+func (l *ListenerConfig) init() {
+ // implemented on Linux, Android, and Windows only
+}
diff --git a/util/net/listener_init_linux.go b/client/net/listener_init_linux.go
similarity index 100%
rename from util/net/listener_init_linux.go
rename to client/net/listener_init_linux.go
diff --git a/client/net/listener_init_windows.go b/client/net/listener_init_windows.go
new file mode 100644
index 000000000..a9399b5f1
--- /dev/null
+++ b/client/net/listener_init_windows.go
@@ -0,0 +1,8 @@
+package net
+
+func (l *ListenerConfig) init() {
+ // TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses.
+ // For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case
+ // the interface will be selected that serves the default route.
+ l.ListenConfig.Control = applyUnicastIFToSocket
+}
diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go
new file mode 100644
index 000000000..0bb5ad67d
--- /dev/null
+++ b/client/net/listener_listen.go
@@ -0,0 +1,153 @@
+//go:build !ios
+
+package net
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "sync"
+
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+
+ nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/internal/routemanager/util"
+ "github.com/netbirdio/netbird/client/net/hooks"
+)
+
+// ListenPacket listens on the network address and returns a PacketConn
+// which includes support for write hooks.
+func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
+ if CustomRoutingDisabled() || AdvancedRouting() {
+ return l.ListenConfig.ListenPacket(ctx, network, address)
+ }
+
+ pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
+ if err != nil {
+ return nil, fmt.Errorf("listen packet: %w", err)
+ }
+ connID := hooks.GenerateConnID()
+
+ return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
+}
+
+// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
+type PacketConn struct {
+ net.PacketConn
+ ID hooks.ConnectionID
+ seenAddrs *sync.Map
+}
+
+// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
+func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
+ if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
+ log.Errorf("Failed to call write hooks: %v", err)
+ }
+ return c.PacketConn.WriteTo(b, addr)
+}
+
+// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
+func (c *PacketConn) Close() error {
+ defer c.seenAddrs.Clear()
+ return closeConn(c.ID, c.PacketConn)
+}
+
+// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
+type UDPConn struct {
+ *net.UDPConn
+ ID hooks.ConnectionID
+ seenAddrs *sync.Map
+}
+
+// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
+func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
+ if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
+ log.Errorf("Failed to call write hooks: %v", err)
+ }
+ return c.UDPConn.WriteTo(b, addr)
+}
+
+// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
+func (c *UDPConn) Close() error {
+ defer c.seenAddrs.Clear()
+ return closeConn(c.ID, c.UDPConn)
+}
+
+// RemoveAddress removes an address from the seen cache and triggers removal hooks.
+func (c *PacketConn) RemoveAddress(addr string) {
+ if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
+ return
+ }
+
+ ipStr, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ log.Errorf("Error splitting IP address and port: %v", err)
+ return
+ }
+
+ ipAddr, err := netip.ParseAddr(ipStr)
+ if err != nil {
+ log.Errorf("Error parsing IP address %s: %v", ipStr, err)
+ return
+ }
+
+ prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen())
+
+ addressRemoveHooks := hooks.GetAddressRemoveHooks()
+ if len(addressRemoveHooks) == 0 {
+ return
+ }
+
+ for _, hook := range addressRemoveHooks {
+ if err := hook(c.ID, prefix); err != nil {
+ log.Errorf("Error executing listener address remove hook: %v", err)
+ }
+ }
+}
+
+// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality
+func WrapPacketConn(conn net.PacketConn) net.PacketConn {
+ if AdvancedRouting() {
+ // hooks not required for advanced routing
+ return conn
+ }
+ return &PacketConn{
+ PacketConn: conn,
+ ID: hooks.GenerateConnID(),
+ seenAddrs: &sync.Map{},
+ }
+}
+
+func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error {
+ if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded {
+ return nil
+ }
+
+ writeHooks := hooks.GetWriteHooks()
+ if len(writeHooks) == 0 {
+ return nil
+ }
+
+ udpAddr, ok := addr.(*net.UDPAddr)
+ if !ok {
+ return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr)
+ }
+
+ prefix, err := util.GetPrefixFromIP(udpAddr.IP)
+ if err != nil {
+ return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err)
+ }
+
+ log.Debugf("Listener resolved IP for %s: %s", addr, prefix)
+
+ var merr *multierror.Error
+ for _, hook := range writeHooks {
+ if err := hook(id, prefix); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err))
+ }
+ }
+
+ return nberrors.FormatErrorOrNil(merr)
+}
diff --git a/util/net/listener_listen_ios.go b/client/net/listener_listen_ios.go
similarity index 100%
rename from util/net/listener_listen_ios.go
rename to client/net/listener_listen_ios.go
diff --git a/util/net/net.go b/client/net/net.go
similarity index 81%
rename from util/net/net.go
rename to client/net/net.go
index fdcf4ee6a..a97de9d59 100644
--- a/util/net/net.go
+++ b/client/net/net.go
@@ -5,8 +5,6 @@ import (
"math/big"
"net"
"net/netip"
-
- "github.com/google/uuid"
)
const (
@@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool {
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
}
-// ConnectionID provides a globally unique identifier for network connections.
-// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
-type ConnectionID string
-
-type AddHookFunc func(connID ConnectionID, IP net.IP) error
-type RemoveHookFunc func(connID ConnectionID) error
-
-// GenerateConnID generates a unique identifier for each connection.
-func GenerateConnID() ConnectionID {
- return ConnectionID(uuid.NewString())
-}
-
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
var endIP net.IP
addr := network.Addr().AsSlice()
diff --git a/util/net/net_linux.go b/client/net/net_linux.go
similarity index 100%
rename from util/net/net_linux.go
rename to client/net/net_linux.go
diff --git a/util/net/net_test.go b/client/net/net_test.go
similarity index 100%
rename from util/net/net_test.go
rename to client/net/net_test.go
diff --git a/client/net/net_windows.go b/client/net/net_windows.go
new file mode 100644
index 000000000..649d83aaf
--- /dev/null
+++ b/client/net/net_windows.go
@@ -0,0 +1,284 @@
+package net
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+ "unsafe"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+const (
+ // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
+ IpUnicastIf = 31
+ Ipv6UnicastIf = 31
+
+ // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options
+ Ipv6V6only = 27
+)
+
+// GetBestInterfaceFunc is set at runtime to avoid import cycle
+var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error)
+
+// nativeToBigEndian converts a uint32 from native byte order to big-endian
+func nativeToBigEndian(v uint32) uint32 {
+ return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24
+}
+
+// parseDestinationAddress parses the destination address from various formats
+func parseDestinationAddress(network, address string) (netip.Addr, error) {
+ if address == "" {
+ if strings.HasSuffix(network, "6") {
+ return netip.IPv6Unspecified(), nil
+ }
+ return netip.IPv4Unspecified(), nil
+ }
+
+ if addrPort, err := netip.ParseAddrPort(address); err == nil {
+ return addrPort.Addr(), nil
+ }
+
+ if dest, err := netip.ParseAddr(address); err == nil {
+ return dest, nil
+ }
+
+ host, _, err := net.SplitHostPort(address)
+ if err != nil {
+ // No port, treat whole string as host
+ host = address
+ }
+
+ if host == "" {
+ if strings.HasSuffix(network, "6") {
+ return netip.IPv6Unspecified(), nil
+ }
+ return netip.IPv4Unspecified(), nil
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
+ if err != nil || len(ips) == 0 {
+ return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err)
+ }
+
+ dest, ok := netip.AddrFromSlice(ips[0].IP)
+ if !ok {
+ return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP)
+ }
+
+ if ips[0].Zone != "" {
+ dest = dest.WithZone(ips[0].Zone)
+ }
+
+ return dest, nil
+}
+
+func getInterfaceFromZone(zone string) *net.Interface {
+ if zone == "" {
+ return nil
+ }
+
+ idx, err := strconv.Atoi(zone)
+ if err != nil {
+ log.Debugf("invalid zone format for Windows (expected numeric): %s", zone)
+ return nil
+ }
+
+ iface, err := net.InterfaceByIndex(idx)
+ if err != nil {
+ log.Debugf("failed to get interface by index %d from zone: %v", idx, err)
+ return nil
+ }
+
+ return iface
+}
+
+type interfaceSelection struct {
+ iface4 *net.Interface
+ iface6 *net.Interface
+}
+
+func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection {
+ iface := getInterfaceFromZone(zone)
+ if iface == nil {
+ return nil
+ }
+
+ if dest.Is6() {
+ return &interfaceSelection{iface6: iface}
+ }
+ return &interfaceSelection{iface4: iface}
+}
+
+func selectInterfaceForUnspecified() (*interfaceSelection, error) {
+ if GetBestInterfaceFunc == nil {
+ return nil, errors.New("GetBestInterfaceFunc not initialized")
+ }
+
+ var result interfaceSelection
+ vpnIfaceName := GetVPNInterfaceName()
+
+ if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil {
+ result.iface4 = iface4
+ } else {
+ log.Debugf("No IPv4 default route found: %v", err)
+ }
+
+ if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil {
+ result.iface6 = iface6
+ } else {
+ log.Debugf("No IPv6 default route found: %v", err)
+ }
+
+ if result.iface4 == nil && result.iface6 == nil {
+ return nil, errors.New("no default routes found")
+ }
+
+ return &result, nil
+}
+
+func selectInterface(dest netip.Addr) (*interfaceSelection, error) {
+ if zone := dest.Zone(); zone != "" {
+ if selection := selectInterfaceForZone(dest, zone); selection != nil {
+ return selection, nil
+ }
+ }
+
+ if dest.IsUnspecified() {
+ return selectInterfaceForUnspecified()
+ }
+
+ if GetBestInterfaceFunc == nil {
+ return nil, errors.New("GetBestInterfaceFunc not initialized")
+ }
+
+ iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName())
+ if err != nil {
+ return nil, fmt.Errorf("find route for %s: %w", dest, err)
+ }
+
+ if dest.Is6() {
+ return &interfaceSelection{iface6: iface}, nil
+ }
+ return &interfaceSelection{iface4: iface}, nil
+}
+
+func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error {
+ ifaceIndexBE := nativeToBigEndian(uint32(iface.Index))
+ if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil {
+ return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
+ }
+ return nil
+}
+
+func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error {
+ if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil {
+ return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
+ }
+ return nil
+}
+
+func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error {
+ // The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.)
+ // Never generic ones (udp, tcp, ip)
+
+ switch {
+ case strings.HasSuffix(network, "4"):
+ // IPv4-only socket (udp4, tcp4, ip4)
+ return setUnicastIfIPv4(fd, network, selection, address)
+
+ case strings.HasSuffix(network, "6"):
+ // IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only
+ return setUnicastIfIPv6(fd, network, selection, address)
+ }
+
+ // Shouldn't reach here based on Go's documented behavior
+ return fmt.Errorf("unexpected network type: %s", network)
+}
+
+func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error {
+ if selection.iface4 == nil {
+ return nil
+ }
+
+ if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
+ return err
+ }
+
+ log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address)
+ return nil
+}
+
+func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error {
+ isDualStack := checkDualStack(fd)
+
+ // For dual-stack sockets, also set the IPv4 option
+ if isDualStack && selection.iface4 != nil {
+ if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
+ return err
+ }
+ log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address)
+ }
+
+ if selection.iface6 == nil {
+ return nil
+ }
+
+ if err := setIPv6UnicastIF(fd, selection.iface6); err != nil {
+ return err
+ }
+
+ log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address)
+ return nil
+}
+
+func checkDualStack(fd uintptr) bool {
+ var v6Only int
+ v6OnlyLen := int32(unsafe.Sizeof(v6Only))
+ err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen)
+ return err == nil && v6Only == 0
+}
+
+// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address
+func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error {
+ if !AdvancedRouting() {
+ return nil
+ }
+
+ dest, err := parseDestinationAddress(network, address)
+ if err != nil {
+ return err
+ }
+
+ dest = dest.Unmap()
+
+ if !dest.IsValid() {
+ return fmt.Errorf("invalid destination address for %s", address)
+ }
+
+ selection, err := selectInterface(dest)
+ if err != nil {
+ return err
+ }
+
+ var controlErr error
+ err = c.Control(func(fd uintptr) {
+ controlErr = setUnicastIf(fd, network, selection, address)
+ })
+
+ if err != nil {
+ return fmt.Errorf("control: %w", err)
+ }
+
+ return controlErr
+}
diff --git a/util/net/protectsocket_android.go b/client/net/protectsocket_android.go
similarity index 100%
rename from util/net/protectsocket_android.go
rename to client/net/protectsocket_android.go
diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh
index 2422d2683..7c9fa021a 100755
--- a/client/netbird-entrypoint.sh
+++ b/client/netbird-entrypoint.sh
@@ -2,7 +2,7 @@
set -eEuo pipefail
: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"}
-: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"}
+: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"}
NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}"
export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}"
service_pids=()
@@ -39,7 +39,7 @@ wait_for_message() {
info "not waiting for log line ${message@Q} due to zero timeout."
elif test -n "${log_file_path}"; then
info "waiting for log line ${message@Q} for ${timeout} seconds..."
- grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null)
+ grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null)
else
info "log file unsupported, sleeping for ${timeout} seconds..."
sleep "${timeout}"
@@ -81,7 +81,7 @@ wait_for_daemon_startup() {
login_if_needed() {
local timeout="${1}"
- if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then
+ if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then
info "already logged in, skipping 'netbird up'..."
else
info "logging in..."
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go
index 13cc8cdbc..0ea294c53 100644
--- a/client/proto/daemon.pb.go
+++ b/client/proto/daemon.pb.go
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
-// protoc v5.29.3
+// protoc v6.32.1
// source: daemon.proto
package proto
@@ -826,8 +826,10 @@ type StatusRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
- unknownFields protoimpl.UnknownFields
- sizeCache protoimpl.SizeCache
+ // the UI do not using this yet, but CLIs could use it to wait until the status is ready
+ WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
}
func (x *StatusRequest) Reset() {
@@ -874,6 +876,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
return false
}
+func (x *StatusRequest) GetWaitForReady() bool {
+ if x != nil && x.WaitForReady != nil {
+ return *x.WaitForReady
+ }
+ return false
+}
+
type StatusResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// status of the server.
@@ -4904,10 +4913,12 @@ const file_daemon_proto_rawDesc = "" +
"\f_profileNameB\v\n" +
"\t_username\"\f\n" +
"\n" +
- "UpResponse\"g\n" +
+ "UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
- "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
+ "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" +
+ "\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" +
+ "\r_waitForReady\"\x82\x01\n" +
"\x0eStatusResponse\x12\x16\n" +
"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
"\n" +
@@ -5491,6 +5502,7 @@ func file_daemon_proto_init() {
}
file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
+ file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
file_daemon_proto_msgTypes[26].OneofWrappers = []any{
(*PortInfo_Port)(nil),
(*PortInfo_Range_)(nil),
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index de62493af..2d904cb32 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -194,6 +194,8 @@ message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
+ // the UI do not using this yet, but CLIs could use it to wait until the status is ready
+ optional bool waitForReady = 3;
}
message StatusResponse{
diff --git a/client/server/server.go b/client/server/server.go
index 4b0c59e4d..864b2c506 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -65,6 +65,9 @@ type Server struct {
mutex sync.Mutex
config *profilemanager.Config
proto.UnimplementedDaemonServiceServer
+ clientRunning bool // protected by mutex
+ clientRunningChan chan struct{}
+ clientGiveUpChan chan struct{}
connectClient *internal.ConnectClient
@@ -103,6 +106,11 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
func (s *Server) Start() error {
s.mutex.Lock()
defer s.mutex.Unlock()
+
+ if s.clientRunning {
+ return nil
+ }
+
state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil {
@@ -172,8 +180,10 @@ func (s *Server) Start() error {
return nil
}
- go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
-
+ s.clientRunning = true
+ s.clientRunningChan = make(chan struct{})
+ s.clientGiveUpChan = make(chan struct{})
+ go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
@@ -204,12 +214,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
-func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status,
- runningChan chan struct{},
-) {
- backOff := getConnectWithBackoff(ctx)
- retryStarted := false
+func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
+ defer func() {
+ s.mutex.Lock()
+ s.clientRunning = false
+ s.mutex.Unlock()
+ }()
+ if s.config.DisableAutoConnect {
+ if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
+ log.Debugf("run client connection exited with error: %v", err)
+ }
+ log.Tracef("client connection exited")
+ return
+ }
+
+ backOff := getConnectWithBackoff(ctx)
go func() {
t := time.NewTicker(24 * time.Hour)
for {
@@ -218,89 +238,36 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage
t.Stop()
return
case <-t.C:
- if retryStarted {
-
- mgmtState := statusRecorder.GetManagementState()
- signalState := statusRecorder.GetSignalState()
- if mgmtState.Connected && signalState.Connected {
- log.Tracef("resetting status")
- retryStarted = false
- } else {
- log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
- }
+ mgmtState := statusRecorder.GetManagementState()
+ signalState := statusRecorder.GetSignalState()
+ if mgmtState.Connected && signalState.Connected {
+ log.Tracef("resetting status")
+ backOff.Reset()
+ } else {
+ log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
}
}
}()
runOperation := func() error {
- log.Tracef("running client connection")
- s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
- s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
-
- err := s.connectClient.Run(runningChan)
+ err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
+ return err
}
- if config.DisableAutoConnect {
- return backoff.Permanent(err)
- }
-
- if !retryStarted {
- retryStarted = true
- backOff.Reset()
- }
-
- log.Tracef("client connection exited")
- return fmt.Errorf("client connection exited")
+ log.Tracef("client connection exited gracefully, do not need to retry")
+ return nil
}
- err := backoff.Retry(runOperation, backOff)
- if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
- log.Errorf("received an error when trying to connect: %v", err)
- } else {
- log.Tracef("retry canceled")
- }
-}
-
-// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
-func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
- initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
- maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
- maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
- multiplier := defaultRetryMultiplier
-
- if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
- // parse the multiplier from the environment variable string value to float64
- value, err := strconv.ParseFloat(envValue, 64)
- if err != nil {
- log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
- } else {
- multiplier = value
- }
+ if err := backoff.Retry(runOperation, backOff); err != nil {
+ log.Errorf("operation failed: %v", err)
}
- return backoff.WithContext(&backoff.ExponentialBackOff{
- InitialInterval: initialInterval,
- RandomizationFactor: 1,
- Multiplier: multiplier,
- MaxInterval: maxInterval,
- MaxElapsedTime: maxElapsedTime, // 14 days
- Stop: backoff.Stop,
- Clock: backoff.SystemClock,
- }, ctx)
-}
-
-// parseEnvDuration parses the environment variable and returns the duration
-func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
- if envValue := os.Getenv(envVar); envValue != "" {
- if duration, err := time.ParseDuration(envValue); err == nil {
- return duration
- }
- log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
+ if giveUpChan != nil {
+ close(giveUpChan)
}
- return defaultDuration
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -423,7 +390,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
if s.actCancel != nil {
s.actCancel()
}
- ctx, cancel := context.WithCancel(s.rootCtx)
+ ctx, cancel := context.WithCancel(callerCtx)
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
@@ -433,11 +400,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.actCancel = cancel
s.mutex.Unlock()
- if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil {
+ if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err)
}
- state := internal.CtxGetState(ctx)
+ state := internal.CtxGetState(s.rootCtx)
defer func() {
status, err := state.Status()
if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) {
@@ -650,6 +617,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
// Up starts engine work in the daemon.
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
s.mutex.Lock()
+ if s.clientRunning {
+ state := internal.CtxGetState(s.rootCtx)
+ status, err := state.Status()
+ if err != nil {
+ s.mutex.Unlock()
+ return nil, err
+ }
+ if status == internal.StatusNeedsLogin {
+ s.actCancel()
+ }
+ s.mutex.Unlock()
+
+ return s.waitForUp(callerCtx)
+ }
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
@@ -665,16 +646,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
if err != nil {
return nil, err
}
+
if status != internal.StatusIdle {
return nil, fmt.Errorf("up already in progress: current status %s", status)
}
- // it should be nil here, but .
+ // it should be nil here, but in case it isn't we cancel it.
if s.actCancel != nil {
s.actCancel()
}
ctx, cancel := context.WithCancel(s.rootCtx)
-
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
ctx = metadata.NewOutgoingContext(ctx, md)
@@ -717,23 +698,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
+ s.clientRunning = true
+ s.clientRunningChan = make(chan struct{})
+ s.clientGiveUpChan = make(chan struct{})
+ go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
+
+ return s.waitForUp(callerCtx)
+}
+
+// todo: handle potential race conditions
+func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
- runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
- go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
- for {
- select {
- case <-runningChan:
- s.isSessionActive.Store(true)
- return &proto.UpResponse{}, nil
- case <-callerCtx.Done():
- log.Debug("context done, stopping the wait for engine to become ready")
- return nil, callerCtx.Err()
- case <-timeoutCtx.Done():
- log.Debug("up is timed out, stopping the wait for engine to become ready")
- return nil, timeoutCtx.Err()
- }
+ select {
+ case <-s.clientGiveUpChan:
+ return nil, fmt.Errorf("client gave up to connect")
+ case <-s.clientRunningChan:
+ s.isSessionActive.Store(true)
+ return &proto.UpResponse{}, nil
+ case <-callerCtx.Done():
+ log.Debug("context done, stopping the wait for engine to become ready")
+ return nil, callerCtx.Err()
+ case <-timeoutCtx.Done():
+ log.Debug("up is timed out, stopping the wait for engine to become ready")
+ return nil, timeoutCtx.Err()
}
}
@@ -1007,12 +996,46 @@ func (s *Server) Status(
ctx context.Context,
msg *proto.StatusRequest,
) (*proto.StatusResponse, error) {
- if ctx.Err() != nil {
- return nil, ctx.Err()
- }
-
s.mutex.Lock()
- defer s.mutex.Unlock()
+ clientRunning := s.clientRunning
+ s.mutex.Unlock()
+
+ if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
+ state := internal.CtxGetState(s.rootCtx)
+ status, err := state.Status()
+ if err != nil {
+ return nil, err
+ }
+
+ if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
+ s.actCancel()
+ }
+
+ ticker := time.NewTicker(1 * time.Second)
+ defer ticker.Stop()
+ loop:
+ for {
+ select {
+ case <-s.clientGiveUpChan:
+ ticker.Stop()
+ break loop
+ case <-s.clientRunningChan:
+ ticker.Stop()
+ break loop
+ case <-ticker.C:
+ status, err := state.Status()
+ if err != nil {
+ continue
+ }
+ if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
+ s.actCancel()
+ }
+ continue
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ }
+ }
status, err := internal.CtxGetState(s.rootCtx).Status()
if err != nil {
@@ -1194,6 +1217,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
}, nil
}
+// AddProfile adds a new profile to the daemon.
+func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if s.checkProfilesDisabled() {
+ return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
+ }
+
+ if msg.ProfileName == "" || msg.Username == "" {
+ return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
+ }
+
+ if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
+ log.Errorf("failed to create profile: %v", err)
+ return nil, fmt.Errorf("failed to create profile: %w", err)
+ }
+
+ return &proto.AddProfileResponse{}, nil
+}
+
+// RemoveProfile removes a profile from the daemon.
+func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
+ return nil, err
+ }
+
+ if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
+ log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
+ }
+
+ if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
+ log.Errorf("failed to remove profile: %v", err)
+ return nil, fmt.Errorf("failed to remove profile: %w", err)
+ }
+
+ return &proto.RemoveProfileResponse{}, nil
+}
+
+// ListProfiles lists all profiles in the daemon.
+func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if msg.Username == "" {
+ return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
+ }
+
+ profiles, err := s.profileManager.ListProfiles(msg.Username)
+ if err != nil {
+ log.Errorf("failed to list profiles: %v", err)
+ return nil, fmt.Errorf("failed to list profiles: %w", err)
+ }
+
+ response := &proto.ListProfilesResponse{
+ Profiles: make([]*proto.Profile, len(profiles)),
+ }
+ for i, profile := range profiles {
+ response.Profiles[i] = &proto.Profile{
+ Name: profile.Name,
+ IsActive: profile.IsActive,
+ }
+ }
+
+ return response, nil
+}
+
+// GetActiveProfile returns the active profile in the daemon.
+func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ activeProfile, err := s.profileManager.GetActiveProfileState()
+ if err != nil {
+ log.Errorf("failed to get active profile state: %v", err)
+ return nil, fmt.Errorf("failed to get active profile state: %w", err)
+ }
+
+ return &proto.GetActiveProfileResponse{
+ ProfileName: activeProfile.Name,
+ Username: activeProfile.Username,
+ }, nil
+}
+
+// GetFeatures returns the features supported by the daemon.
+func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ features := &proto.GetFeaturesResponse{
+ DisableProfiles: s.checkProfilesDisabled(),
+ DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
+ }
+
+ return features, nil
+}
+
+func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
+ log.Tracef("running client connection")
+ s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
+ s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
+ if err := s.connectClient.Run(runningChan); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *Server) checkProfilesDisabled() bool {
+ // Check if the environment variable is set to disable profiles
+ if s.profilesDisabled {
+ return true
+ }
+
+ return false
+}
+
+func (s *Server) checkUpdateSettingsDisabled() bool {
+ // Check if the environment variable is set to disable profiles
+ if s.updateSettingsDisabled {
+ return true
+ }
+
+ return false
+}
+
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
@@ -1205,6 +1356,45 @@ func (s *Server) onSessionExpire() {
}
}
+// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
+func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
+ initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
+ maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
+ maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
+ multiplier := defaultRetryMultiplier
+
+ if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
+ // parse the multiplier from the environment variable string value to float64
+ value, err := strconv.ParseFloat(envValue, 64)
+ if err != nil {
+ log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
+ } else {
+ multiplier = value
+ }
+ }
+
+ return backoff.WithContext(&backoff.ExponentialBackOff{
+ InitialInterval: initialInterval,
+ RandomizationFactor: 1,
+ Multiplier: multiplier,
+ MaxInterval: maxInterval,
+ MaxElapsedTime: maxElapsedTime, // 14 days
+ Stop: backoff.Stop,
+ Clock: backoff.SystemClock,
+ }, ctx)
+}
+
+// parseEnvDuration parses the environment variable and returns the duration
+func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
+ if envValue := os.Getenv(envVar); envValue != "" {
+ if duration, err := time.ParseDuration(envValue); err == nil {
+ return duration
+ }
+ log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
+ }
+ return defaultDuration
+}
+
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
@@ -1320,121 +1510,3 @@ func sendTerminalNotification() error {
return wallCmd.Wait()
}
-
-// AddProfile adds a new profile to the daemon.
-func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
- s.mutex.Lock()
- defer s.mutex.Unlock()
-
- if s.checkProfilesDisabled() {
- return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
- }
-
- if msg.ProfileName == "" || msg.Username == "" {
- return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
- }
-
- if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
- log.Errorf("failed to create profile: %v", err)
- return nil, fmt.Errorf("failed to create profile: %w", err)
- }
-
- return &proto.AddProfileResponse{}, nil
-}
-
-// RemoveProfile removes a profile from the daemon.
-func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
- s.mutex.Lock()
- defer s.mutex.Unlock()
-
- if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
- return nil, err
- }
-
- if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
- log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
- }
-
- if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
- log.Errorf("failed to remove profile: %v", err)
- return nil, fmt.Errorf("failed to remove profile: %w", err)
- }
-
- return &proto.RemoveProfileResponse{}, nil
-}
-
-// ListProfiles lists all profiles in the daemon.
-func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
- s.mutex.Lock()
- defer s.mutex.Unlock()
-
- if msg.Username == "" {
- return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
- }
-
- profiles, err := s.profileManager.ListProfiles(msg.Username)
- if err != nil {
- log.Errorf("failed to list profiles: %v", err)
- return nil, fmt.Errorf("failed to list profiles: %w", err)
- }
-
- response := &proto.ListProfilesResponse{
- Profiles: make([]*proto.Profile, len(profiles)),
- }
- for i, profile := range profiles {
- response.Profiles[i] = &proto.Profile{
- Name: profile.Name,
- IsActive: profile.IsActive,
- }
- }
-
- return response, nil
-}
-
-// GetActiveProfile returns the active profile in the daemon.
-func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
- s.mutex.Lock()
- defer s.mutex.Unlock()
-
- activeProfile, err := s.profileManager.GetActiveProfileState()
- if err != nil {
- log.Errorf("failed to get active profile state: %v", err)
- return nil, fmt.Errorf("failed to get active profile state: %w", err)
- }
-
- return &proto.GetActiveProfileResponse{
- ProfileName: activeProfile.Name,
- Username: activeProfile.Username,
- }, nil
-}
-
-// GetFeatures returns the features supported by the daemon.
-func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
- s.mutex.Lock()
- defer s.mutex.Unlock()
-
- features := &proto.GetFeaturesResponse{
- DisableProfiles: s.checkProfilesDisabled(),
- DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
- }
-
- return features, nil
-}
-
-func (s *Server) checkProfilesDisabled() bool {
- // Check if the environment variable is set to disable profiles
- if s.profilesDisabled {
- return true
- }
-
- return false
-}
-
-func (s *Server) checkUpdateSettingsDisabled() bool {
- // Check if the environment variable is set to disable profiles
- if s.updateSettingsDisabled {
- return true
- }
-
- return false
-}
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 24ff9fb0c..755925003 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -10,25 +10,25 @@ import (
"time"
"github.com/golang/mock/gomock"
- "github.com/stretchr/testify/require"
- "go.opentelemetry.io/otel"
-
- "github.com/netbirdio/management-integrations/integrations"
- "github.com/netbirdio/netbird/management/internals/server/config"
- "github.com/netbirdio/netbird/management/server/groups"
-
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
+ "github.com/netbirdio/management-integrations/integrations"
+
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
+ "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
+ "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
- s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
+ s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
@@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) {
profName := "default"
+ u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
+ require.NoError(t, err)
+
ic := profilemanager.ConfigInput{
- ConfigPath: filepath.Join(tempDir, profName+".json"),
+ ConfigPath: filepath.Join(tempDir, profName+".json"),
+ ManagementURL: u.String(),
}
_, err = profilemanager.UpdateOrCreateConfig(ic)
@@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) {
}
s := New(ctx, "console", "", false, false)
-
err = s.Start()
require.NoError(t, err)
- u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
- require.NoError(t, err)
- s.config = &profilemanager.Config{
- ManagementURL: u,
- }
-
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
@@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) {
Username: &currUser.Username,
}
_, err = s.Up(upCtx, upReq)
+ log.Errorf("error from Up: %v", err)
assert.Contains(t, err.Error(), "context deadline exceeded")
}
@@ -294,15 +292,20 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
- ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
+
+ ctrl := gomock.NewController(t)
+ t.Cleanup(ctrl.Finish)
+
+ permissionsManagerMock := permissions.NewMockManager(ctrl)
+ peersManager := peers.NewManager(store, permissionsManagerMock)
+ settingsManagerMock := settings.NewMockManager(ctrl)
+
+ ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- ctrl := gomock.NewController(t)
- t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
- permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
diff --git a/client/system/info.go b/client/system/info.go
index 90abf864b..1e4342e34 100644
--- a/client/system/info.go
+++ b/client/system/info.go
@@ -6,6 +6,7 @@ import (
"net/netip"
"strings"
+ log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -114,14 +115,6 @@ func (i *Info) SetFlags(
}
}
-// StaticInfo is an object that contains machine information that does not change
-type StaticInfo struct {
- SystemSerialNumber string
- SystemProductName string
- SystemManufacturer string
- Environment Environment
-}
-
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)
@@ -199,6 +192,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
+ log.Debugf("gathering system information with checks: %d", len(checks))
processCheckPaths := make([]string, 0)
for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
@@ -208,16 +202,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
if err != nil {
return nil, err
}
+ log.Debugf("gathering process check information completed")
info := GetInfo(ctx)
info.Files = files
+ log.Debugf("all system information gathered successfully")
return info, nil
}
-
-// UpdateStaticInfo asynchronously updates static system and platform information
-func UpdateStaticInfo() {
- go func() {
- _ = updateStaticInfo()
- }()
-}
diff --git a/client/system/info_android.go b/client/system/info_android.go
index 56fe0741d..78895bfa8 100644
--- a/client/system/info_android.go
+++ b/client/system/info_android.go
@@ -15,6 +15,11 @@ import (
"github.com/netbirdio/netbird/version"
)
+// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
+func UpdateStaticInfoAsync() {
+ // do nothing
+}
+
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
kernel := "android"
diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go
index f105ada60..caa344737 100644
--- a/client/system/info_darwin.go
+++ b/client/system/info_darwin.go
@@ -19,6 +19,10 @@ import (
"github.com/netbirdio/netbird/version"
)
+func UpdateStaticInfoAsync() {
+ go updateStaticInfo()
+}
+
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
utsname := unix.Utsname{}
@@ -41,7 +45,7 @@ func GetInfo(ctx context.Context) *Info {
}
start := time.Now()
- si := updateStaticInfo()
+ si := getStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
}
diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go
index bed6711de..8e1353151 100644
--- a/client/system/info_freebsd.go
+++ b/client/system/info_freebsd.go
@@ -18,6 +18,11 @@ import (
"github.com/netbirdio/netbird/version"
)
+// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
+func UpdateStaticInfoAsync() {
+ // do nothing
+}
+
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
out := _getInfo()
diff --git a/client/system/info_ios.go b/client/system/info_ios.go
index 897ec0a35..705c37920 100644
--- a/client/system/info_ios.go
+++ b/client/system/info_ios.go
@@ -10,6 +10,11 @@ import (
"github.com/netbirdio/netbird/version"
)
+// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
+func UpdateStaticInfoAsync() {
+ // do nothing
+}
+
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
diff --git a/client/system/info_linux.go b/client/system/info_linux.go
index 9bfc82009..6c7a23b95 100644
--- a/client/system/info_linux.go
+++ b/client/system/info_linux.go
@@ -23,6 +23,10 @@ var (
getSystemInfo = defaultSysInfoImplementation
)
+func UpdateStaticInfoAsync() {
+ go updateStaticInfo()
+}
+
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
info := _getInfo()
@@ -48,7 +52,7 @@ func GetInfo(ctx context.Context) *Info {
}
start := time.Now()
- si := updateStaticInfo()
+ si := getStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
}
diff --git a/client/system/info_windows.go b/client/system/info_windows.go
index 6f05ded20..d7f8f30aa 100644
--- a/client/system/info_windows.go
+++ b/client/system/info_windows.go
@@ -2,187 +2,51 @@ package system
import (
"context"
- "fmt"
"os"
"runtime"
- "strings"
"time"
log "github.com/sirupsen/logrus"
- "github.com/yusufpapurcu/wmi"
- "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/version"
)
-type Win32_OperatingSystem struct {
- Caption string
-}
-
-type Win32_ComputerSystem struct {
- Manufacturer string
-}
-
-type Win32_ComputerSystemProduct struct {
- Name string
-}
-
-type Win32_BIOS struct {
- SerialNumber string
+func UpdateStaticInfoAsync() {
+ go updateStaticInfo()
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
- osName, osVersion := getOSNameAndVersion()
- buildVersion := getBuildVersion()
-
- addrs, err := networkAddresses()
- if err != nil {
- log.Warnf("failed to discover network addresses: %s", err)
- }
-
start := time.Now()
- si := updateStaticInfo()
+ si := getStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
}
gio := &Info{
Kernel: "windows",
- OSVersion: osVersion,
+ OSVersion: si.OSVersion,
Platform: "unknown",
- OS: osName,
+ OS: si.OSName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
- KernelVersion: buildVersion,
- NetworkAddresses: addrs,
+ KernelVersion: si.BuildVersion,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
}
+ addrs, err := networkAddresses()
+ if err != nil {
+ log.Warnf("failed to discover network addresses: %s", err)
+ } else {
+ gio.NetworkAddresses = addrs
+ }
+
systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
-
return gio
}
-
-func sysInfo() (serialNumber string, productName string, manufacturer string) {
- var err error
- serialNumber, err = sysNumber()
- if err != nil {
- log.Warnf("failed to get system serial number: %s", err)
- }
-
- productName, err = sysProductName()
- if err != nil {
- log.Warnf("failed to get system product name: %s", err)
- }
-
- manufacturer, err = sysManufacturer()
- if err != nil {
- log.Warnf("failed to get system manufacturer: %s", err)
- }
-
- return serialNumber, productName, manufacturer
-}
-
-func getOSNameAndVersion() (string, string) {
- var dst []Win32_OperatingSystem
- query := wmi.CreateQuery(&dst, "")
- err := wmi.Query(query, &dst)
- if err != nil {
- log.Error(err)
- return "Windows", getBuildVersion()
- }
-
- if len(dst) == 0 {
- return "Windows", getBuildVersion()
- }
-
- split := strings.Split(dst[0].Caption, " ")
-
- if len(split) <= 3 {
- return "Windows", getBuildVersion()
- }
-
- name := split[1]
- version := split[2]
- if split[2] == "Server" {
- name = fmt.Sprintf("%s %s", split[1], split[2])
- version = split[3]
- }
-
- return name, version
-}
-
-func getBuildVersion() string {
- k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
- if err != nil {
- log.Error(err)
- return "0.0.0.0"
- }
- defer func() {
- deferErr := k.Close()
- if deferErr != nil {
- log.Error(deferErr)
- }
- }()
-
- major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
- if err != nil {
- log.Error(err)
- }
- minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
- if err != nil {
- log.Error(err)
- }
- build, _, err := k.GetStringValue("CurrentBuildNumber")
- if err != nil {
- log.Error(err)
- }
- // Update Build Revision
- ubr, _, err := k.GetIntegerValue("UBR")
- if err != nil {
- log.Error(err)
- }
- ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
- return ver
-}
-
-func sysNumber() (string, error) {
- var dst []Win32_BIOS
- query := wmi.CreateQuery(&dst, "")
- err := wmi.Query(query, &dst)
- if err != nil {
- return "", err
- }
- return dst[0].SerialNumber, nil
-}
-
-func sysProductName() (string, error) {
- var dst []Win32_ComputerSystemProduct
- query := wmi.CreateQuery(&dst, "")
- err := wmi.Query(query, &dst)
- if err != nil {
- return "", err
- }
- // `ComputerSystemProduct` could be empty on some virtualized systems
- if len(dst) < 1 {
- return "unknown", nil
- }
- return dst[0].Name, nil
-}
-
-func sysManufacturer() (string, error) {
- var dst []Win32_ComputerSystem
- query := wmi.CreateQuery(&dst, "")
- err := wmi.Query(query, &dst)
- if err != nil {
- return "", err
- }
- return dst[0].Manufacturer, nil
-}
diff --git a/client/system/static_info.go b/client/system/static_info.go
index f178ec932..12a2663a1 100644
--- a/client/system/static_info.go
+++ b/client/system/static_info.go
@@ -3,12 +3,7 @@
package system
import (
- "context"
"sync"
- "time"
-
- "github.com/netbirdio/netbird/client/system/detect_cloud"
- "github.com/netbirdio/netbird/client/system/detect_platform"
)
var (
@@ -16,25 +11,26 @@ var (
once sync.Once
)
-func updateStaticInfo() StaticInfo {
+// StaticInfo is an object that contains machine information that does not change
+type StaticInfo struct {
+ SystemSerialNumber string
+ SystemProductName string
+ SystemManufacturer string
+ Environment Environment
+
+ // Windows specific fields
+ OSName string
+ OSVersion string
+ BuildVersion string
+}
+
+func updateStaticInfo() {
once.Do(func() {
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- wg := sync.WaitGroup{}
- wg.Add(3)
- go func() {
- staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
- wg.Done()
- }()
- go func() {
- staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
- wg.Done()
- }()
- go func() {
- staticInfo.Environment.Platform = detect_platform.Detect(ctx)
- wg.Done()
- }()
- wg.Wait()
+ staticInfo = newStaticInfo()
})
+}
+
+func getStaticInfo() StaticInfo {
+ updateStaticInfo()
return staticInfo
}
diff --git a/client/system/static_info_stub.go b/client/system/static_info_stub.go
deleted file mode 100644
index faa3e700b..000000000
--- a/client/system/static_info_stub.go
+++ /dev/null
@@ -1,8 +0,0 @@
-//go:build android || freebsd || ios
-
-package system
-
-// updateStaticInfo returns an empty implementation for unsupported platforms
-func updateStaticInfo() StaticInfo {
- return StaticInfo{}
-}
diff --git a/client/system/static_info_update.go b/client/system/static_info_update.go
new file mode 100644
index 000000000..af8b1e266
--- /dev/null
+++ b/client/system/static_info_update.go
@@ -0,0 +1,35 @@
+//go:build (linux && !android) || (darwin && !ios)
+
+package system
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "github.com/netbirdio/netbird/client/system/detect_cloud"
+ "github.com/netbirdio/netbird/client/system/detect_platform"
+)
+
+func newStaticInfo() StaticInfo {
+ si := StaticInfo{}
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ wg := sync.WaitGroup{}
+ wg.Add(3)
+ go func() {
+ si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
+ wg.Done()
+ }()
+ go func() {
+ si.Environment.Cloud = detect_cloud.Detect(ctx)
+ wg.Done()
+ }()
+ go func() {
+ si.Environment.Platform = detect_platform.Detect(ctx)
+ wg.Done()
+ }()
+ wg.Wait()
+ return si
+}
diff --git a/client/system/static_info_update_windows.go b/client/system/static_info_update_windows.go
new file mode 100644
index 000000000..5f232c1de
--- /dev/null
+++ b/client/system/static_info_update_windows.go
@@ -0,0 +1,184 @@
+package system
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/yusufpapurcu/wmi"
+ "golang.org/x/sys/windows/registry"
+
+ "github.com/netbirdio/netbird/client/system/detect_cloud"
+ "github.com/netbirdio/netbird/client/system/detect_platform"
+)
+
+type Win32_OperatingSystem struct {
+ Caption string
+}
+
+type Win32_ComputerSystem struct {
+ Manufacturer string
+}
+
+type Win32_ComputerSystemProduct struct {
+ Name string
+}
+
+type Win32_BIOS struct {
+ SerialNumber string
+}
+
+func newStaticInfo() StaticInfo {
+ si := StaticInfo{}
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
+ wg.Done()
+ }()
+ wg.Add(1)
+ go func() {
+ si.Environment.Cloud = detect_cloud.Detect(ctx)
+ wg.Done()
+ }()
+ wg.Add(1)
+ go func() {
+ si.Environment.Platform = detect_platform.Detect(ctx)
+ wg.Done()
+ }()
+ wg.Add(1)
+ go func() {
+ si.OSName, si.OSVersion = getOSNameAndVersion()
+ wg.Done()
+ }()
+ wg.Add(1)
+ go func() {
+ si.BuildVersion = getBuildVersion()
+ wg.Done()
+ }()
+ wg.Wait()
+ return si
+}
+
+func sysInfo() (serialNumber string, productName string, manufacturer string) {
+ var err error
+ serialNumber, err = sysNumber()
+ if err != nil {
+ log.Warnf("failed to get system serial number: %s", err)
+ }
+
+ productName, err = sysProductName()
+ if err != nil {
+ log.Warnf("failed to get system product name: %s", err)
+ }
+
+ manufacturer, err = sysManufacturer()
+ if err != nil {
+ log.Warnf("failed to get system manufacturer: %s", err)
+ }
+
+ return serialNumber, productName, manufacturer
+}
+
+func sysNumber() (string, error) {
+ var dst []Win32_BIOS
+ query := wmi.CreateQuery(&dst, "")
+ err := wmi.Query(query, &dst)
+ if err != nil {
+ return "", err
+ }
+ return dst[0].SerialNumber, nil
+}
+
+func sysProductName() (string, error) {
+ var dst []Win32_ComputerSystemProduct
+ query := wmi.CreateQuery(&dst, "")
+ err := wmi.Query(query, &dst)
+ if err != nil {
+ return "", err
+ }
+ // `ComputerSystemProduct` could be empty on some virtualized systems
+ if len(dst) < 1 {
+ return "unknown", nil
+ }
+ return dst[0].Name, nil
+}
+
+func sysManufacturer() (string, error) {
+ var dst []Win32_ComputerSystem
+ query := wmi.CreateQuery(&dst, "")
+ err := wmi.Query(query, &dst)
+ if err != nil {
+ return "", err
+ }
+ return dst[0].Manufacturer, nil
+}
+
+func getOSNameAndVersion() (string, string) {
+ var dst []Win32_OperatingSystem
+ query := wmi.CreateQuery(&dst, "")
+ err := wmi.Query(query, &dst)
+ if err != nil {
+ log.Error(err)
+ return "Windows", getBuildVersion()
+ }
+
+ if len(dst) == 0 {
+ return "Windows", getBuildVersion()
+ }
+
+ split := strings.Split(dst[0].Caption, " ")
+
+ if len(split) <= 3 {
+ return "Windows", getBuildVersion()
+ }
+
+ name := split[1]
+ version := split[2]
+ if split[2] == "Server" {
+ name = fmt.Sprintf("%s %s", split[1], split[2])
+ version = split[3]
+ }
+
+ return name, version
+}
+
+func getBuildVersion() string {
+ k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
+ if err != nil {
+ log.Error(err)
+ return "0.0.0.0"
+ }
+ defer func() {
+ deferErr := k.Close()
+ if deferErr != nil {
+ log.Error(deferErr)
+ }
+ }()
+
+ major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
+ if err != nil {
+ log.Error(err)
+ }
+ minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
+ if err != nil {
+ log.Error(err)
+ }
+ build, _, err := k.GetStringValue("CurrentBuildNumber")
+ if err != nil {
+ log.Error(err)
+ }
+ // Update Build Revision
+ ubr, _, err := k.GetIntegerValue("UBR")
+ if err != nil {
+ log.Error(err)
+ }
+ ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
+ return ver
+}
diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go
index 78dd696db..533bf23d3 100644
--- a/client/ui/client_ui.go
+++ b/client/ui/client_ui.go
@@ -469,6 +469,17 @@ func (s *serviceClient) getConnectionForm() *widget.Form {
}
func (s *serviceClient) saveSettings() {
+ // Check if update settings are disabled by daemon
+ features, err := s.getFeatures()
+ if err != nil {
+ log.Errorf("failed to get features from daemon: %v", err)
+ // Continue with default behavior if features can't be retrieved
+ } else if features != nil && features.DisableUpdateSettings {
+ log.Warn("Configuration updates are disabled by daemon")
+ dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
+ return
+ }
+
if err := s.validateSettings(); err != nil {
dialog.ShowError(err, s.wSettings)
return
@@ -605,6 +616,28 @@ func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error {
return fmt.Errorf("set config: %w", err)
}
+ // Reconnect if connected to apply the new settings
+ go func() {
+ status, err := conn.Status(s.ctx, &proto.StatusRequest{})
+ if err != nil {
+ log.Errorf("get service status: %v", err)
+ return
+ }
+ if status.Status == string(internal.StatusConnected) {
+ // run down & up
+ _, err = conn.Down(s.ctx, &proto.DownRequest{})
+ if err != nil {
+ log.Errorf("down service: %v", err)
+ }
+
+ _, err = conn.Up(s.ctx, &proto.UpRequest{})
+ if err != nil {
+ log.Errorf("up service: %v", err)
+ return
+ }
+ }
+ }()
+
return nil
}
@@ -637,7 +670,7 @@ func (s *serviceClient) getNetworkForm() *widget.Form {
{Text: "Disable DNS", Widget: s.sDisableDNS},
{Text: "Disable Client Routes", Widget: s.sDisableClientRoutes},
{Text: "Disable Server Routes", Widget: s.sDisableServerRoutes},
- {Text: "Block LAN Access", Widget: s.sBlockLANAccess},
+ {Text: "Disable LAN Access", Widget: s.sBlockLANAccess},
},
}
}
diff --git a/flow/client/client.go b/flow/client/client.go
index 949824065..603fd6882 100644
--- a/flow/client/client.go
+++ b/flow/client/client.go
@@ -20,9 +20,9 @@ import (
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
+ nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/flow/proto"
"github.com/netbirdio/netbird/util/embeddedroots"
- nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
type GRPCClient struct {
diff --git a/go.mod b/go.mod
index 4b9064dbc..f135b5fc6 100644
--- a/go.mod
+++ b/go.mod
@@ -12,19 +12,18 @@ require (
github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.27.6
- github.com/pion/ice/v3 v3.0.2
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.3.0
- golang.org/x/crypto v0.39.0
- golang.org/x/sys v0.33.0
+ golang.org/x/crypto v0.40.0
+ golang.org/x/sys v0.34.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
- google.golang.org/grpc v1.64.1
- google.golang.org/protobuf v1.36.6
+ google.golang.org/grpc v1.73.0
+ google.golang.org/protobuf v1.36.8
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -63,16 +62,18 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
- github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190
+ github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203
- github.com/pion/logging v0.2.2
+ github.com/pion/ice/v4 v4.0.0-00010101000000-000000000000
+ github.com/pion/logging v0.2.4
github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
- github.com/pion/transport/v3 v3.0.1
+ github.com/pion/stun/v3 v3.0.0
+ github.com/pion/transport/v3 v3.0.7
github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.22.0
@@ -94,18 +95,18 @@ require (
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.1.3
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
- go.opentelemetry.io/otel v1.26.0
+ go.opentelemetry.io/otel v1.35.0
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
- go.opentelemetry.io/otel/metric v1.26.0
- go.opentelemetry.io/otel/sdk/metric v1.26.0
+ go.opentelemetry.io/otel/metric v1.35.0
+ go.opentelemetry.io/otel/sdk/metric v1.35.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
- golang.org/x/net v0.40.0
- golang.org/x/oauth2 v0.27.0
- golang.org/x/sync v0.15.0
- golang.org/x/term v0.32.0
+ golang.org/x/net v0.42.0
+ golang.org/x/oauth2 v0.28.0
+ golang.org/x/sync v0.16.0
+ golang.org/x/term v0.33.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
@@ -118,7 +119,7 @@ require (
require (
cloud.google.com/go/auth v0.3.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
- cloud.google.com/go/compute/metadata v0.3.0 // indirect
+ cloud.google.com/go/compute/metadata v0.6.0 // indirect
dario.cat/mergo v1.0.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
@@ -214,8 +215,10 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
- github.com/pion/mdns v0.0.12 // indirect
+ github.com/pion/dtls/v3 v3.0.7 // indirect
+ github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
+ github.com/pion/turn/v4 v4.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
@@ -231,22 +234,23 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
+ github.com/wlynxg/anet v0.0.3 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect
+ go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
- go.opentelemetry.io/otel/sdk v1.26.0 // indirect
- go.opentelemetry.io/otel/trace v1.26.0 // indirect
+ go.opentelemetry.io/otel/sdk v1.35.0 // indirect
+ go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.25.0 // indirect
- golang.org/x/text v0.26.0 // indirect
+ golang.org/x/text v0.27.0 // indirect
golang.org/x/time v0.5.0 // indirect
- golang.org/x/tools v0.33.0 // indirect
+ golang.org/x/tools v0.34.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
- google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
- google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
+ google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
)
@@ -259,6 +263,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
-replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e
+replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
diff --git a/go.sum b/go.sum
index f3a9a1788..f2c51bf67 100644
--- a/go.sum
+++ b/go.sum
@@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
-cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
-cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
+cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
+cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk=
@@ -502,10 +502,10 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
-github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
-github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 h1:/ZbExdcDwRq6XgTpTf5I1DPqnC3eInEf0fcmkqR8eSg=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
+github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
+github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
@@ -547,21 +547,29 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
-github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
+github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
+github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
-github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8=
-github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk=
+github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
+github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
+github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
+github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
+github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
+github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
-github.com/pion/transport/v3 v3.0.1 h1:gDTlPJwROfSfz6QfSi0ZmeCSkFcnWWiiR9ES0ouANiM=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
+github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
+github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
+github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
+github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -591,8 +599,8 @@ github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
-github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
-github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
+github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
+github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
@@ -684,6 +692,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
+github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
+github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -715,26 +725,28 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
+go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
+go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
-go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs=
-go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4=
+go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
+go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
-go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30=
-go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4=
-go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8=
-go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs=
-go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y=
-go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE=
-go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA=
-go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0=
+go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
+go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
+go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
+go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
+go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
+go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
+go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
+go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
@@ -766,8 +778,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
-golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
-golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
+golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
+golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -866,8 +878,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
-golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
-golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
+golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
+golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -881,8 +893,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
-golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M=
-golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
+golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
+golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -900,8 +912,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
-golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
+golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -973,8 +985,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
-golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
+golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -987,8 +999,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
-golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
-golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
+golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
+golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1005,8 +1017,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
-golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
-golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
+golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
+golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1071,8 +1083,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
-golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
-golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
+golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
+golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -1155,10 +1167,11 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D
google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A=
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
-google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U=
-google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
+google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
+google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM=
+google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -1179,8 +1192,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
-google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
-google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
+google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
+google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
@@ -1195,11 +1208,10 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
-google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
+google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
+google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh
index 2d7c65cbe..cfec1000e 100644
--- a/infrastructure_files/getting-started-with-zitadel.sh
+++ b/infrastructure_files/getting-started-with-zitadel.sh
@@ -328,6 +328,45 @@ delete_auto_service_user() {
echo "$PARSED_RESPONSE"
}
+delete_default_zitadel_admin() {
+ INSTANCE_URL=$1
+ PAT=$2
+
+ # Search for the default zitadel-admin user
+ RESPONSE=$(
+ curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \
+ -H "Authorization: Bearer $PAT" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "queries": [
+ {
+ "userNameQuery": {
+ "userName": "zitadel-admin@",
+ "method": "TEXT_QUERY_METHOD_STARTS_WITH"
+ }
+ }
+ ]
+ }'
+ )
+
+ DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty')
+
+ if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then
+ echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID"
+
+ RESPONSE=$(
+ curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \
+ -H "Authorization: Bearer $PAT" \
+ -H "Content-Type: application/json" \
+ )
+ PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"')
+ handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE"
+
+ else
+ echo "Default zitadel-admin user not found: $RESPONSE"
+ fi
+}
+
init_zitadel() {
echo -e "\nInitializing Zitadel with NetBird's applications\n"
INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
@@ -346,6 +385,9 @@ init_zitadel() {
echo -n "Waiting for Zitadel to become ready "
wait_api "$INSTANCE_URL" "$PAT"
+ echo "Deleting default zitadel-admin user..."
+ delete_default_zitadel_admin "$INSTANCE_URL" "$PAT"
+
# create the zitadel project
echo "Creating new zitadel project"
PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT")
diff --git a/management/README.md b/management/README.md
index 1122a9e76..c70285d43 100644
--- a/management/README.md
+++ b/management/README.md
@@ -111,3 +111,6 @@ Generate gRpc code:
#!/bin/bash
protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=.
```
+
+
+
diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go
index b351f3bc9..984a56a39 100644
--- a/management/internals/server/controllers.go
+++ b/management/internals/server/controllers.go
@@ -20,7 +20,11 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
- integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
+ integratedPeerValidator, err := integrations.NewIntegratedValidator(
+ context.Background(),
+ s.PeersManager(),
+ s.SettingsManager(),
+ s.EventStore())
if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err)
}
diff --git a/management/server/account.go b/management/server/account.go
index f217eadb3..ee9f294a4 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -104,6 +104,8 @@ type DefaultAccountManager struct {
accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
+ loginFilter *loginFilter
+
disableDefaultPolicy bool
}
@@ -211,6 +213,7 @@ func BuildManager(
proxyController: proxyController,
settingsManager: settingsManager,
permissionsManager: permissionsManager,
+ loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy,
}
@@ -1133,7 +1136,18 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID
- err := am.Store.SaveUser(ctx, newUser)
+
+ settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID)
+ if err != nil {
+ return "", err
+ }
+
+ if settings != nil && settings.Extra != nil && settings.Extra.UserApprovalRequired {
+ newUser.Blocked = true
+ newUser.PendingApproval = true
+ }
+
+ err = am.Store.SaveUser(ctx, newUser)
if err != nil {
return "", err
}
@@ -1143,7 +1157,11 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
return "", err
}
- am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
+ if newUser.PendingApproval {
+ am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
+ } else {
+ am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
+ }
return domainAccountID, nil
}
@@ -1612,6 +1630,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
+func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
+ return am.loginFilter.allowLogin(wgPubKey, metahash)
+}
+
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
@@ -1628,6 +1650,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
+ metahash := metaHash(meta, realIP.String())
+ am.loginFilter.addLogin(peerPubKey, metahash)
+
return peer, netMap, postureChecks, nil
}
@@ -1636,7 +1661,6 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
-
return nil
}
@@ -1690,7 +1714,9 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account
log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err)
continue
}
- peers = append(peers, peer)
+ if peer.UserID != "" {
+ peers = append(peers, peer)
+ }
}
if len(peers) > 0 {
err := am.expireAndUpdatePeers(ctx, accountID, peers)
@@ -1786,6 +1812,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
+ Extra: &types.ExtraSettings{
+ UserApprovalRequired: true,
+ },
},
Onboarding: types.AccountOnboarding{
OnboardingFlowPending: true,
@@ -1892,6 +1921,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
+ Extra: &types.ExtraSettings{
+ UserApprovalRequired: true,
+ },
},
}
diff --git a/management/server/account/manager.go b/management/server/account/manager.go
index c7a39004a..30fbbbc3e 100644
--- a/management/server/account/manager.go
+++ b/management/server/account/manager.go
@@ -32,6 +32,8 @@ type Manager interface {
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
+ ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
+ RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error)
@@ -123,4 +125,5 @@ type Manager interface {
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
+ AllowSync(string, uint64) bool
}
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 252be23f7..81a921bf9 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -15,6 +15,7 @@ import (
"time"
"github.com/golang/mock/gomock"
+ "github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -25,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -3046,19 +3048,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
- minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
- minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
+ testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "sync", "syncAndMark")
}
- if msPerOp < minExpected {
- b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
- }
-
- if msPerOp > (maxExpected * 1.1) {
- b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
+ if msPerOp > maxExpected {
+ b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
@@ -3121,19 +3118,14 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
- minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
- minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
+ testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "existingPeer")
}
- if msPerOp < minExpected {
- b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
- }
-
- if msPerOp > (maxExpected * 1.1) {
- b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
+ if msPerOp > maxExpected {
+ b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
@@ -3196,24 +3188,44 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
- minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
- minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
+ testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
}
- if msPerOp < minExpected {
- b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
- }
-
- if msPerOp > (maxExpected * 1.1) {
- b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
+ if msPerOp > maxExpected {
+ b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
+func TestMain(m *testing.M) {
+ exitCode := m.Run()
+
+ if exitCode == 0 && os.Getenv("CI") == "true" {
+ runID := os.Getenv("GITHUB_RUN_ID")
+ storeEngine := os.Getenv("NETBIRD_STORE_ENGINE")
+ err := push.New("http://localhost:9091", "account_manager_benchmark").
+ Collector(testing_tools.BenchmarkDuration).
+ Grouping("ci_run", runID).
+ Grouping("store_engine", storeEngine).
+ Push()
+ if err != nil {
+ log.Printf("Failed to push metrics: %v", err)
+ } else {
+ time.Sleep(1 * time.Minute)
+ _ = push.New("http://localhost:9091", "account_manager_benchmark").
+ Grouping("ci_run", runID).
+ Grouping("store_engine", storeEngine).
+ Delete()
+ }
+ }
+
+ os.Exit(exitCode)
+}
+
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@@ -3594,3 +3606,93 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
require.Error(t, err, "should fail with invalid peer ID")
})
}
+
+func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a domain-based account with user approval enabled
+ existingAccountID := "existing-account"
+ account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false)
+ account.Settings.Extra = &types.ExtraSettings{
+ UserApprovalRequired: true,
+ }
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Set the account as domain primary account
+ account.IsDomainPrimaryAccount = true
+ account.DomainCategory = types.PrivateCategory
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Test adding new user to existing account with approval required
+ newUserID := "new-user-id"
+ userAuth := nbcontext.UserAuth{
+ UserId: newUserID,
+ Domain: "example.com",
+ DomainCategory: types.PrivateCategory,
+ }
+
+ acc, err := manager.Store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ require.True(t, acc.IsDomainPrimaryAccount, "Account should be primary for the domain")
+ require.Equal(t, "example.com", acc.Domain, "Account domain should match")
+
+ returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
+ require.NoError(t, err)
+ require.Equal(t, existingAccountID, returnedAccountID)
+
+ // Verify user was created with pending approval
+ user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
+ require.NoError(t, err)
+ assert.True(t, user.Blocked, "User should be blocked when approval is required")
+ assert.True(t, user.PendingApproval, "User should be pending approval")
+ assert.Equal(t, existingAccountID, user.AccountID)
+}
+
+func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a domain-based account without user approval
+ ownerUserAuth := nbcontext.UserAuth{
+ UserId: "owner-user",
+ Domain: "example.com",
+ DomainCategory: types.PrivateCategory,
+ }
+ existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth)
+ require.NoError(t, err)
+
+ // Modify the account to disable user approval
+ account, err := manager.Store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ account.Settings.Extra = &types.ExtraSettings{
+ UserApprovalRequired: false,
+ }
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Test adding new user to existing account without approval required
+ newUserID := "new-user-id"
+ userAuth := nbcontext.UserAuth{
+ UserId: newUserID,
+ Domain: "example.com",
+ DomainCategory: types.PrivateCategory,
+ }
+
+ returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
+ require.NoError(t, err)
+ require.Equal(t, existingAccountID, returnedAccountID)
+
+ // Verify user was created without pending approval
+ user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
+ require.NoError(t, err)
+ assert.False(t, user.Blocked, "User should not be blocked when approval is not required")
+ assert.False(t, user.PendingApproval, "User should not be pending approval")
+ assert.Equal(t, existingAccountID, user.AccountID)
+}
diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go
index 6f9619597..5c5989f84 100644
--- a/management/server/activity/codes.go
+++ b/management/server/activity/codes.go
@@ -177,6 +177,8 @@ const (
AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88
+ UserApproved Activity = 89
+ UserRejected Activity = 90
AccountDeleted Activity = 99999
)
@@ -284,6 +286,8 @@ var activityMap = map[Activity]Code{
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
+ UserApproved: {"User approved", "user.approve"},
+ UserRejected: {"User rejected", "user.reject"},
}
// StringCode returns a string code of the activity
diff --git a/management/server/group.go b/management/server/group.go
index 86bc0d8a0..487cb6d97 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -202,35 +202,45 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
- var groupsToSave []*types.Group
var updateAccountPeers bool
- err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
- groupIDs := make([]string, 0, len(groups))
- for _, newGroup := range groups {
+ var globalErr error
+ groupIDs := make([]string, 0, len(groups))
+ for _, newGroup := range groups {
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
- groupsToSave = append(groupsToSave, newGroup)
+
+ if err = transaction.CreateGroup(ctx, newGroup); err != nil {
+ return err
+ }
+
+ err = transaction.IncrementNetworkSerial(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
- }
- updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
+ return nil
+ })
if err != nil {
- return err
+ log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
+ if len(groupIDs) == 1 {
+ return err
+ }
+ globalErr = errors.Join(globalErr, err)
+ // continue updating other groups
}
+ }
- if err = transaction.CreateGroups(ctx, accountID, groupsToSave); err != nil {
- return err
- }
-
- return transaction.IncrementNetworkSerial(ctx, accountID)
- })
+ updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
@@ -243,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
am.UpdateAccountPeers(ctx, accountID)
}
- return nil
+ return globalErr
}
// UpdateGroups updates groups in the account.
@@ -260,35 +270,45 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
- var groupsToSave []*types.Group
var updateAccountPeers bool
- err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
- groupIDs := make([]string, 0, len(groups))
- for _, newGroup := range groups {
+ var globalErr error
+ groupIDs := make([]string, 0, len(groups))
+ for _, newGroup := range groups {
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
- groupsToSave = append(groupsToSave, newGroup)
- groupIDs = append(groupIDs, newGroup.ID)
+
+ if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
+ return err
+ }
+
+ err = transaction.IncrementNetworkSerial(ctx, accountID)
+ if err != nil {
+ return err
+ }
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
- }
- updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
+ groupIDs = append(groupIDs, newGroup.ID)
+
+ return nil
+ })
if err != nil {
- return err
+ log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
+ if len(groups) == 1 {
+ return err
+ }
+ globalErr = errors.Join(globalErr, err)
+ // continue updating other groups
}
+ }
- if err = transaction.UpdateGroups(ctx, accountID, groupsToSave); err != nil {
- return err
- }
-
- return transaction.IncrementNetworkSerial(ctx, accountID)
- })
+ updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
@@ -301,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
am.UpdateAccountPeers(ctx, accountID)
}
- return nil
+ return globalErr
}
// prepareGroupEvents prepares a list of event functions to be stored.
@@ -584,13 +604,6 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
newGroup.ID = xid.New().String()
}
- for _, peerID := range newGroup.Peers {
- _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
- if err != nil {
- return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
- }
- }
-
return nil
}
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index a637cf02d..60a00207e 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -2,9 +2,11 @@ package server
import (
"context"
+ "errors"
"fmt"
"net"
"net/netip"
+ "os"
"strings"
"sync"
"time"
@@ -38,20 +40,28 @@ import (
internalStatus "github.com/netbirdio/netbird/shared/management/status"
)
+const (
+ envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
+ envBlockPeers = "NB_BLOCK_SAME_PEERS"
+)
+
// GRPCServer an instance of a Management gRPC API server
type GRPCServer struct {
accountManager account.Manager
settingsManager settings.Manager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
- peersUpdateManager *PeersUpdateManager
- config *nbconfig.Config
- secretsManager SecretsManager
- appMetrics telemetry.AppMetrics
- ephemeralManager *EphemeralManager
- peerLocks sync.Map
- authManager auth.Manager
- integratedPeerValidator integrated_validator.IntegratedValidator
+ peersUpdateManager *PeersUpdateManager
+ config *nbconfig.Config
+ secretsManager SecretsManager
+ appMetrics telemetry.AppMetrics
+ ephemeralManager *EphemeralManager
+ peerLocks sync.Map
+ authManager auth.Manager
+
+ logBlockedPeers bool
+ blockPeersWithSameConfig bool
+ integratedPeerValidator integrated_validator.IntegratedValidator
}
// NewServer creates a new Management server
@@ -82,18 +92,23 @@ func NewServer(
}
}
+ logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
+ blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
+
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
- peersUpdateManager: peersUpdateManager,
- accountManager: accountManager,
- settingsManager: settingsManager,
- config: config,
- secretsManager: secretsManager,
- authManager: authManager,
- appMetrics: appMetrics,
- ephemeralManager: ephemeralManager,
- integratedPeerValidator: integratedPeerValidator,
+ peersUpdateManager: peersUpdateManager,
+ accountManager: accountManager,
+ settingsManager: settingsManager,
+ config: config,
+ secretsManager: secretsManager,
+ authManager: authManager,
+ appMetrics: appMetrics,
+ ephemeralManager: ephemeralManager,
+ logBlockedPeers: logBlockedPeers,
+ blockPeersWithSameConfig: blockPeersWithSameConfig,
+ integratedPeerValidator: integratedPeerValidator,
}, nil
}
@@ -136,9 +151,6 @@ func getRealIP(ctx context.Context) net.IP {
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
reqStart := time.Now()
- if s.appMetrics != nil {
- s.appMetrics.GRPCMetrics().CountSyncRequest()
- }
ctx := srv.Context()
@@ -147,6 +159,25 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if err != nil {
return err
}
+ realIP := getRealIP(ctx)
+ sRealIP := realIP.String()
+ peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
+ metahashed := metaHash(peerMeta, sRealIP)
+ if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
+ if s.appMetrics != nil {
+ s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
+ }
+ if s.logBlockedPeers {
+ log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
+ }
+ if s.blockPeersWithSameConfig {
+ return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
+ }
+ }
+
+ if s.appMetrics != nil {
+ s.appMetrics.GRPCMetrics().CountSyncRequest()
+ }
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
@@ -172,14 +203,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
- realIP := getRealIP(ctx)
- log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
+ log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
if syncReq.GetMeta() == nil {
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
- peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
+ peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err)
@@ -198,7 +228,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
if s.appMetrics != nil {
- s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
+ s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
}
unlock()
@@ -228,6 +258,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
+ log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
@@ -345,6 +376,9 @@ func mapError(ctx context.Context, err error) error {
default:
}
}
+ if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) {
+ return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error())
+ }
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request")
}
@@ -436,16 +470,9 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now()
- defer func() {
- if s.appMetrics != nil {
- s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
- }
- }()
- if s.appMetrics != nil {
- s.appMetrics.GRPCMetrics().CountLoginRequest()
- }
realIP := getRealIP(ctx)
- log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
+ sRealIP := realIP.String()
+ log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -453,6 +480,24 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, err
}
+ peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
+ metahashed := metaHash(peerMeta, sRealIP)
+ if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
+ if s.logBlockedPeers {
+ log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
+ }
+ if s.appMetrics != nil {
+ s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
+ }
+ if s.blockPeersWithSameConfig {
+ return nil, internalStatus.ErrPeerAlreadyLoggedIn
+ }
+ }
+
+ if s.appMetrics != nil {
+ s.appMetrics.GRPCMetrics().CountLoginRequest()
+ }
+
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
@@ -463,6 +508,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
+ defer func() {
+ if s.appMetrics != nil {
+ s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
+ }
+ }()
+
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
@@ -483,7 +534,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
- Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
+ Meta: peerMeta,
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
@@ -949,8 +1000,6 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, mapError(ctx, err)
}
- s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
-
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
return &proto.Empty{}, nil
diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go
index 9f2afe29d..f1552d0ea 100644
--- a/management/server/http/handlers/accounts/accounts_handler.go
+++ b/management/server/http/handlers/accounts/accounts_handler.go
@@ -11,11 +11,11 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/settings"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
- "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
)
const (
@@ -198,6 +198,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
if req.Settings.Extra != nil {
settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
+ UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
@@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
if settings.Extra != nil {
apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
+ UserApprovalRequired: settings.Extra.UserApprovalRequired,
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go
index 1dad33a6f..4b9b79fdc 100644
--- a/management/server/http/handlers/accounts/accounts_handler_test.go
+++ b/management/server/http/handlers/accounts/accounts_handler_test.go
@@ -15,11 +15,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/status"
)
func initAccountsTestData(t *testing.T, account *types.Account) *handler {
diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go
index 414c7b1b9..af501e151 100644
--- a/management/server/http/handlers/peers/peers_handler.go
+++ b/management/server/http/handlers/peers/peers_handler.go
@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
)
// Handler is a handler that returns peers of the account
@@ -354,7 +354,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
}
return &api.Peer{
- CreatedAt: peer.CreatedAt,
+ CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
@@ -391,33 +391,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
}
return &api.PeerBatch{
- CreatedAt: peer.CreatedAt,
- Id: peer.ID,
- Name: peer.Name,
- Ip: peer.IP.String(),
- ConnectionIp: peer.Location.ConnectionIP.String(),
- Connected: peer.Status.Connected,
- LastSeen: peer.Status.LastSeen,
- Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
- KernelVersion: peer.Meta.KernelVersion,
- GeonameId: int(peer.Location.GeoNameID),
- Version: peer.Meta.WtVersion,
- Groups: groupsInfo,
- SshEnabled: peer.SSHEnabled,
- Hostname: peer.Meta.Hostname,
- UserId: peer.UserID,
- UiVersion: peer.Meta.UIVersion,
- DnsLabel: fqdn(peer, dnsDomain),
- ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
- LoginExpirationEnabled: peer.LoginExpirationEnabled,
- LastLogin: peer.GetLastLogin(),
- LoginExpired: peer.Status.LoginExpired,
- AccessiblePeersCount: accessiblePeersCount,
- CountryCode: peer.Location.CountryCode,
- CityName: peer.Location.CityName,
- SerialNumber: peer.Meta.SystemSerialNumber,
-
+ CreatedAt: peer.CreatedAt,
+ Id: peer.ID,
+ Name: peer.Name,
+ Ip: peer.IP.String(),
+ ConnectionIp: peer.Location.ConnectionIP.String(),
+ Connected: peer.Status.Connected,
+ LastSeen: peer.Status.LastSeen,
+ Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
+ KernelVersion: peer.Meta.KernelVersion,
+ GeonameId: int(peer.Location.GeoNameID),
+ Version: peer.Meta.WtVersion,
+ Groups: groupsInfo,
+ SshEnabled: peer.SSHEnabled,
+ Hostname: peer.Meta.Hostname,
+ UserId: peer.UserID,
+ UiVersion: peer.Meta.UIVersion,
+ DnsLabel: fqdn(peer, dnsDomain),
+ ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
+ LoginExpirationEnabled: peer.LoginExpirationEnabled,
+ LastLogin: peer.GetLastLogin(),
+ LoginExpired: peer.Status.LoginExpired,
+ AccessiblePeersCount: accessiblePeersCount,
+ CountryCode: peer.Location.CountryCode,
+ CityName: peer.Location.CityName,
+ SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
+ Ephemeral: peer.Ephemeral,
}
}
diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go
index bcd637db4..4e03e5e9b 100644
--- a/management/server/http/handlers/users/users_handler.go
+++ b/management/server/http/handlers/users/users_handler.go
@@ -9,11 +9,11 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
+ "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
- "github.com/netbirdio/netbird/management/server/types"
- "github.com/netbirdio/netbird/management/server/users"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
@@ -31,6 +31,8 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
+ router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
+ router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
addUsersTokensEndpoint(accountManager, router)
}
@@ -323,17 +325,76 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
}
isCurrent := user.ID == currenUserID
+
return &api.User{
- Id: user.ID,
- Name: user.Name,
- Email: user.Email,
- Role: user.Role,
- AutoGroups: autoGroups,
- Status: userStatus,
- IsCurrent: &isCurrent,
- IsServiceUser: &user.IsServiceUser,
- IsBlocked: user.IsBlocked,
- LastLogin: &user.LastLogin,
- Issued: &user.Issued,
+ Id: user.ID,
+ Name: user.Name,
+ Email: user.Email,
+ Role: user.Role,
+ AutoGroups: autoGroups,
+ Status: userStatus,
+ IsCurrent: &isCurrent,
+ IsServiceUser: &user.IsServiceUser,
+ IsBlocked: user.IsBlocked,
+ LastLogin: &user.LastLogin,
+ Issued: &user.Issued,
+ PendingApproval: user.PendingApproval,
}
}
+
+// approveUser is a POST request to approve a user that is pending approval
+func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
+ return
+ }
+
+ vars := mux.Vars(r)
+ targetUserID := vars["userId"]
+ if len(targetUserID) == 0 {
+ util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
+ return
+ }
+
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+ user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ userResponse := toUserResponse(user, userAuth.UserId)
+ util.WriteJSONObject(r.Context(), w, userResponse)
+}
+
+// rejectUser is a DELETE request to reject a user that is pending approval
+func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodDelete {
+ util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
+ return
+ }
+
+ vars := mux.Vars(r)
+ targetUserID := vars["userId"]
+ if len(targetUserID) == 0 {
+ util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
+ return
+ }
+
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+ err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
+}
diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go
index f7dc81919..e08004218 100644
--- a/management/server/http/handlers/users/users_handler_test.go
+++ b/management/server/http/handlers/users/users_handler_test.go
@@ -16,13 +16,13 @@ import (
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
+ "github.com/netbirdio/netbird/shared/management/http/api"
+ "github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -725,3 +725,133 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri
}
return modules
}
+
+func TestApproveUserEndpoint(t *testing.T) {
+ adminUser := &types.User{
+ Id: "admin-user",
+ Role: types.UserRoleAdmin,
+ AccountID: existingAccountID,
+ AutoGroups: []string{},
+ }
+
+ pendingUser := &types.User{
+ Id: "pending-user",
+ Role: types.UserRoleUser,
+ AccountID: existingAccountID,
+ Blocked: true,
+ PendingApproval: true,
+ AutoGroups: []string{},
+ }
+
+ tt := []struct {
+ name string
+ expectedStatus int
+ expectedBody bool
+ requestingUser *types.User
+ }{
+ {
+ name: "approve user as admin should return 200",
+ expectedStatus: 200,
+ expectedBody: true,
+ requestingUser: adminUser,
+ },
+ }
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ am := &mock_server.MockAccountManager{}
+ am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
+ approvedUserInfo := &types.UserInfo{
+ ID: pendingUser.Id,
+ Email: "pending@example.com",
+ Name: "Pending User",
+ Role: string(pendingUser.Role),
+ AutoGroups: []string{},
+ IsServiceUser: false,
+ IsBlocked: false,
+ PendingApproval: false,
+ LastLogin: time.Now(),
+ Issued: types.UserIssuedAPI,
+ }
+ return approvedUserInfo, nil
+ }
+
+ handler := newHandler(am)
+ router := mux.NewRouter()
+ router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
+
+ req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
+ require.NoError(t, err)
+
+ userAuth := nbcontext.UserAuth{
+ AccountId: existingAccountID,
+ UserId: tc.requestingUser.Id,
+ }
+ ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
+ req = req.WithContext(ctx)
+
+ rr := httptest.NewRecorder()
+ router.ServeHTTP(rr, req)
+
+ assert.Equal(t, tc.expectedStatus, rr.Code)
+
+ if tc.expectedBody {
+ var response api.User
+ err = json.Unmarshal(rr.Body.Bytes(), &response)
+ require.NoError(t, err)
+ assert.Equal(t, "pending-user", response.Id)
+ assert.False(t, response.IsBlocked)
+ assert.False(t, response.PendingApproval)
+ }
+ })
+ }
+}
+
+func TestRejectUserEndpoint(t *testing.T) {
+ adminUser := &types.User{
+ Id: "admin-user",
+ Role: types.UserRoleAdmin,
+ AccountID: existingAccountID,
+ AutoGroups: []string{},
+ }
+
+ tt := []struct {
+ name string
+ expectedStatus int
+ requestingUser *types.User
+ }{
+ {
+ name: "reject user as admin should return 200",
+ expectedStatus: 200,
+ requestingUser: adminUser,
+ },
+ }
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ am := &mock_server.MockAccountManager{}
+ am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
+ return nil
+ }
+
+ handler := newHandler(am)
+ router := mux.NewRouter()
+ router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
+
+ req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
+ require.NoError(t, err)
+
+ userAuth := nbcontext.UserAuth{
+ AccountId: existingAccountID,
+ UserId: tc.requestingUser.Id,
+ }
+ ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
+ req = req.WithContext(ctx)
+
+ rr := httptest.NewRecorder()
+ router.ServeHTTP(rr, req)
+
+ assert.Equal(t, tc.expectedStatus, rr.Code)
+ })
+ }
+}
diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
index 52737e4eb..3fe3fe809 100644
--- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
+++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
@@ -17,8 +17,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
+ "github.com/netbirdio/netbird/shared/management/http/api"
)
const modulePeers = "peers"
@@ -47,7 +48,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -65,7 +66,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
})
}
}
@@ -82,7 +83,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -92,7 +93,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
})
}
}
@@ -109,7 +110,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -119,7 +120,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
})
}
}
@@ -136,7 +137,7 @@ func BenchmarkDeletePeer(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -146,7 +147,7 @@ func BenchmarkDeletePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
})
}
}
diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go
index 9404c4ee4..36b226db0 100644
--- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go
+++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go
@@ -17,8 +17,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
+ "github.com/netbirdio/netbird/shared/management/http/api"
)
// Map to store peers, groups, users, and setupKeys by name
@@ -47,7 +48,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -69,7 +70,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
})
}
}
@@ -86,7 +87,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -109,7 +110,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
})
}
}
@@ -126,7 +127,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -136,7 +137,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
})
}
}
@@ -153,7 +154,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -163,7 +164,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
})
}
}
@@ -180,7 +181,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
b.ResetTimer()
@@ -190,7 +191,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
})
}
}
diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go
index 844b3e7a6..2868a20bd 100644
--- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go
+++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go
@@ -18,8 +18,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
+ "github.com/netbirdio/netbird/shared/management/http/api"
)
const moduleUsers = "users"
@@ -46,7 +47,7 @@ func BenchmarkUpdateUser(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
@@ -71,7 +72,7 @@ func BenchmarkUpdateUser(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
})
}
}
@@ -84,18 +85,18 @@ func BenchmarkGetOneUser(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
+ req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
for i := 0; i < b.N; i++ {
- req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
})
}
}
@@ -110,18 +111,18 @@ func BenchmarkGetAllUsers(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
+ req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
for i := 0; i < b.N; i++ {
- req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
})
}
}
@@ -136,7 +137,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
- apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
+ apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
recorder := httptest.NewRecorder()
@@ -147,7 +148,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
+ testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
})
}
}
diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go
index 9f04e3c24..1079de4aa 100644
--- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go
+++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go
@@ -15,9 +15,10 @@ import (
"github.com/stretchr/testify/assert"
- "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
+ "github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_SetupKeys_Create(t *testing.T) {
@@ -287,7 +288,7 @@ func Test_SetupKeys_Create(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
- apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
+ apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
body, err := json.Marshal(tc.requestBody)
if err != nil {
@@ -572,7 +573,7 @@ func Test_SetupKeys_Update(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
- apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
+ apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
body, err := json.Marshal(tc.requestBody)
if err != nil {
@@ -751,7 +752,7 @@ func Test_SetupKeys_Get(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
- apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
+ apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
@@ -903,7 +904,7 @@ func Test_SetupKeys_GetAll(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
- apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
+ apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId)
@@ -1087,7 +1088,7 @@ func Test_SetupKeys_Delete(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
- apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
+ apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go
new file mode 100644
index 000000000..741f03f18
--- /dev/null
+++ b/management/server/http/testing/testing_tools/channel/channel.go
@@ -0,0 +1,137 @@
+package channel
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/netbirdio/management-integrations/integrations"
+ "github.com/stretchr/testify/assert"
+
+ "github.com/netbirdio/netbird/management/server"
+ "github.com/netbirdio/netbird/management/server/account"
+ "github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/auth"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/geolocation"
+ "github.com/netbirdio/netbird/management/server/groups"
+ http2 "github.com/netbirdio/netbird/management/server/http"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
+ "github.com/netbirdio/netbird/management/server/networks"
+ "github.com/netbirdio/netbird/management/server/networks/resources"
+ "github.com/netbirdio/netbird/management/server/networks/routers"
+ "github.com/netbirdio/netbird/management/server/peers"
+ "github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/settings"
+ "github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/management/server/telemetry"
+ "github.com/netbirdio/netbird/management/server/users"
+)
+
+func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
+ store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
+ if err != nil {
+ t.Fatalf("Failed to create test store: %v", err)
+ }
+ t.Cleanup(cleanup)
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ if err != nil {
+ t.Fatalf("Failed to create metrics: %v", err)
+ }
+
+ peersUpdateManager := server.NewPeersUpdateManager(nil)
+ updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
+ done := make(chan struct{})
+ if validateUpdate {
+ go func() {
+ if expectedPeerUpdate != nil {
+ peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
+ } else {
+ peerShouldNotReceiveUpdate(t, updMsg)
+ }
+ close(done)
+ }()
+ }
+
+ geoMock := &geolocation.Mock{}
+ validatorMock := server.MockIntegratedValidator{}
+ proxyController := integrations.NewController(store)
+ userManager := users.NewManager(store)
+ permissionsManager := permissions.NewManager(store)
+ settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
+ am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
+ if err != nil {
+ t.Fatalf("Failed to create manager: %v", err)
+ }
+
+ // @note this is required so that PAT's validate from store, but JWT's are mocked
+ authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
+ authManagerMock := &auth.MockManager{
+ ValidateAndParseTokenFunc: mockValidateAndParseToken,
+ EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
+ MarkPATUsedFunc: authManager.MarkPATUsed,
+ GetPATInfoFunc: authManager.GetPATInfo,
+ }
+
+ networksManagerMock := networks.NewManagerMock()
+ resourcesManagerMock := resources.NewManagerMock()
+ routersManagerMock := routers.NewManagerMock()
+ groupsManagerMock := groups.NewManagerMock()
+ peersManager := peers.NewManager(store, permissionsManager)
+
+ apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
+ if err != nil {
+ t.Fatalf("Failed to create API handler: %v", err)
+ }
+
+ return apiHandler, am, done
+}
+
+func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
+ t.Helper()
+ select {
+ case msg := <-updateMessage:
+ t.Errorf("Unexpected message received: %+v", msg)
+ case <-time.After(500 * time.Millisecond):
+ return
+ }
+}
+
+func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
+ t.Helper()
+
+ select {
+ case msg := <-updateMessage:
+ if msg == nil {
+ t.Errorf("Received nil update message, expected valid message")
+ }
+ assert.Equal(t, expected, msg)
+ case <-time.After(500 * time.Millisecond):
+ t.Errorf("Timed out waiting for update message")
+ }
+}
+
+func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
+ userAuth := nbcontext.UserAuth{}
+
+ switch token {
+ case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
+ userAuth.UserId = token
+ userAuth.AccountId = "testAccountId"
+ userAuth.Domain = "test.com"
+ userAuth.DomainCategory = "private"
+ case "otherUserId":
+ userAuth.UserId = "otherUserId"
+ userAuth.AccountId = "otherAccountId"
+ userAuth.Domain = "other.com"
+ userAuth.DomainCategory = "private"
+ case "invalidToken":
+ return userAuth, nil, errors.New("invalid token")
+ }
+
+ jwtToken := jwt.New(jwt.SigningMethodHS256)
+ return userAuth, jwtToken, nil
+}
diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go
index 1b82b156e..b7a63b104 100644
--- a/management/server/http/testing/testing_tools/tools.go
+++ b/management/server/http/testing/testing_tools/tools.go
@@ -3,7 +3,6 @@ package testing_tools
import (
"bytes"
"context"
- "errors"
"fmt"
"io"
"net"
@@ -14,32 +13,12 @@ import (
"testing"
"time"
- "github.com/golang-jwt/jwt/v5"
"github.com/prometheus/client_golang/prometheus"
- "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/management-integrations/integrations"
- "github.com/netbirdio/netbird/management/server/peers"
- "github.com/netbirdio/netbird/management/server/permissions"
- "github.com/netbirdio/netbird/management/server/settings"
- "github.com/netbirdio/netbird/management/server/users"
-
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
- "github.com/netbirdio/netbird/management/server/activity"
- "github.com/netbirdio/netbird/management/server/auth"
- nbcontext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/management/server/geolocation"
- "github.com/netbirdio/netbird/management/server/groups"
- nbhttp "github.com/netbirdio/netbird/management/server/http"
- "github.com/netbirdio/netbird/management/server/networks"
- "github.com/netbirdio/netbird/management/server/networks/resources"
- "github.com/netbirdio/netbird/management/server/networks/routers"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
- "github.com/netbirdio/netbird/management/server/store"
- "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
)
@@ -106,90 +85,6 @@ type PerformanceMetrics struct {
MaxMsPerOpCICD float64
}
-func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
- store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
- if err != nil {
- t.Fatalf("Failed to create test store: %v", err)
- }
- t.Cleanup(cleanup)
-
- metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
- if err != nil {
- t.Fatalf("Failed to create metrics: %v", err)
- }
-
- peersUpdateManager := server.NewPeersUpdateManager(nil)
- updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId)
- done := make(chan struct{})
- if validateUpdate {
- go func() {
- if expectedPeerUpdate != nil {
- peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
- } else {
- peerShouldNotReceiveUpdate(t, updMsg)
- }
- close(done)
- }()
- }
-
- geoMock := &geolocation.Mock{}
- validatorMock := server.MockIntegratedValidator{}
- proxyController := integrations.NewController(store)
- userManager := users.NewManager(store)
- permissionsManager := permissions.NewManager(store)
- settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
- am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
- if err != nil {
- t.Fatalf("Failed to create manager: %v", err)
- }
-
- // @note this is required so that PAT's validate from store, but JWT's are mocked
- authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
- authManagerMock := &auth.MockManager{
- ValidateAndParseTokenFunc: mockValidateAndParseToken,
- EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
- MarkPATUsedFunc: authManager.MarkPATUsed,
- GetPATInfoFunc: authManager.GetPATInfo,
- }
-
- networksManagerMock := networks.NewManagerMock()
- resourcesManagerMock := resources.NewManagerMock()
- routersManagerMock := routers.NewManagerMock()
- groupsManagerMock := groups.NewManagerMock()
- peersManager := peers.NewManager(store, permissionsManager)
-
- apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
- if err != nil {
- t.Fatalf("Failed to create API handler: %v", err)
- }
-
- return apiHandler, am, done
-}
-
-func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) {
- t.Helper()
- select {
- case msg := <-updateMessage:
- t.Errorf("Unexpected message received: %+v", msg)
- case <-time.After(500 * time.Millisecond):
- return
- }
-}
-
-func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
- t.Helper()
-
- select {
- case msg := <-updateMessage:
- if msg == nil {
- t.Errorf("Received nil update message, expected valid message")
- }
- assert.Equal(t, expected, msg)
- case <-time.After(500 * time.Millisecond):
- t.Errorf("Timed out waiting for update message")
- }
-}
-
func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request {
t.Helper()
@@ -222,11 +117,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta
return content, expectedStatus == http.StatusOK
}
-func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
+func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) {
b.Helper()
ctx := context.Background()
- account, err := am.GetAccount(ctx, TestAccountId)
+ acc, err := am.GetAccount(ctx, TestAccountId)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
@@ -242,23 +137,23 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true},
UserID: TestUserId,
}
- account.Peers[peer.ID] = peer
+ acc.Peers[peer.ID] = peer
}
// Create users
for i := 0; i < users; i++ {
user := &types.User{
Id: fmt.Sprintf("olduser-%d", i),
- AccountID: account.Id,
+ AccountID: acc.Id,
Role: types.UserRoleUser,
}
- account.Users[user.Id] = user
+ acc.Users[user.Id] = user
}
for i := 0; i < setupKeys; i++ {
key := &types.SetupKey{
Id: fmt.Sprintf("oldkey-%d", i),
- AccountID: account.Id,
+ AccountID: acc.Id,
AutoGroups: []string{"someGroupID"},
UpdatedAt: time.Now().UTC(),
ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)),
@@ -266,11 +161,11 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
Type: "reusable",
UsageLimit: 0,
}
- account.SetupKeys[key.Id] = key
+ acc.SetupKeys[key.Id] = key
}
// Create groups and policies
- account.Policies = make([]*types.Policy, 0, groups)
+ acc.Policies = make([]*types.Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &types.Group{
@@ -281,7 +176,7 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
- account.Groups[groupID] = group
+ acc.Groups[groupID] = group
// Create a policy for this group
policy := &types.Policy{
@@ -301,10 +196,10 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
},
},
}
- account.Policies = append(account.Policies, policy)
+ acc.Policies = append(acc.Policies, policy)
}
- account.PostureChecks = []*posture.Checks{
+ acc.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
@@ -316,52 +211,38 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro
},
}
- err = am.Store.SaveAccount(context.Background(), account)
+ store := am.GetStore()
+
+ err = store.SaveAccount(context.Background(), acc)
if err != nil {
b.Fatalf("Failed to save account: %v", err)
}
}
-func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
+func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
b.Helper()
- branch := os.Getenv("GIT_BRANCH")
- if branch == "" {
- b.Fatalf("environment variable GIT_BRANCH is not set")
- }
-
if recorder.Code != http.StatusOK {
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
}
+ EvaluateBenchmarkResults(b, testCase, duration, module, operation)
+
+}
+
+func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) {
+ b.Helper()
+
+ branch := os.Getenv("GIT_BRANCH")
+ if branch == "" && os.Getenv("CI") == "true" {
+ b.Fatalf("environment variable GIT_BRANCH is not set")
+ }
+
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch)
gauge.Set(msPerOp)
b.ReportMetric(msPerOp, "ms/op")
-
-}
-
-func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
- userAuth := nbcontext.UserAuth{}
-
- switch token {
- case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
- userAuth.UserId = token
- userAuth.AccountId = "testAccountId"
- userAuth.Domain = "test.com"
- userAuth.DomainCategory = "private"
- case "otherUserId":
- userAuth.UserId = "otherUserId"
- userAuth.AccountId = "otherAccountId"
- userAuth.Domain = "other.com"
- userAuth.DomainCategory = "private"
- case "invalidToken":
- return userAuth, nil, errors.New("invalid token")
- }
-
- jwtToken := jwt.New(jwt.SigningMethodHS256)
- return userAuth, jwtToken, nil
}
diff --git a/management/server/loginfilter.go b/management/server/loginfilter.go
new file mode 100644
index 000000000..8604af6e2
--- /dev/null
+++ b/management/server/loginfilter.go
@@ -0,0 +1,160 @@
+package server
+
+import (
+ "hash/fnv"
+ "math"
+ "sync"
+ "time"
+
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+)
+
+const (
+ reconnThreshold = 5 * time.Minute
+ baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
+ reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
+ metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
+)
+
+type lfConfig struct {
+ reconnThreshold time.Duration
+ baseBlockDuration time.Duration
+ reconnLimitForBan int
+ metaChangeLimit int
+}
+
+func initCfg() *lfConfig {
+ return &lfConfig{
+ reconnThreshold: reconnThreshold,
+ baseBlockDuration: baseBlockDuration,
+ reconnLimitForBan: reconnLimitForBan,
+ metaChangeLimit: metaChangeLimit,
+ }
+}
+
+type loginFilter struct {
+ mu sync.RWMutex
+ cfg *lfConfig
+ logged map[string]*peerState
+}
+
+type peerState struct {
+ currentHash uint64
+ sessionCounter int
+ sessionStart time.Time
+ lastSeen time.Time
+ isBanned bool
+ banLevel int
+ banExpiresAt time.Time
+ metaChangeCounter int
+ metaChangeWindowStart time.Time
+}
+
+func newLoginFilter() *loginFilter {
+ return newLoginFilterWithCfg(initCfg())
+}
+
+func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter {
+ return &loginFilter{
+ logged: make(map[string]*peerState),
+ cfg: cfg,
+ }
+}
+
+func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool {
+ l.mu.RLock()
+ defer func() {
+ l.mu.RUnlock()
+ }()
+ state, ok := l.logged[wgPubKey]
+ if !ok {
+ return true
+ }
+ if state.isBanned && time.Now().Before(state.banExpiresAt) {
+ return false
+ }
+ if metaHash != state.currentHash {
+ if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit {
+ return false
+ }
+ }
+ return true
+}
+
+func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
+ now := time.Now()
+ l.mu.Lock()
+ defer func() {
+ l.mu.Unlock()
+ }()
+
+ state, ok := l.logged[wgPubKey]
+
+ if !ok {
+ l.logged[wgPubKey] = &peerState{
+ currentHash: metaHash,
+ sessionCounter: 1,
+ sessionStart: now,
+ lastSeen: now,
+ metaChangeWindowStart: now,
+ metaChangeCounter: 1,
+ }
+ return
+ }
+
+ if state.isBanned && now.After(state.banExpiresAt) {
+ state.isBanned = false
+ }
+
+ if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) {
+ state.banLevel = 0
+ }
+
+ if metaHash != state.currentHash {
+ if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) {
+ state.metaChangeWindowStart = now
+ state.metaChangeCounter = 1
+ } else {
+ state.metaChangeCounter++
+ }
+ state.currentHash = metaHash
+ state.sessionCounter = 1
+ state.sessionStart = now
+ state.lastSeen = now
+ return
+ }
+
+ state.sessionCounter++
+ if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold {
+ state.isBanned = true
+ state.banLevel++
+
+ backoffFactor := math.Pow(2, float64(state.banLevel-1))
+ duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor)
+ state.banExpiresAt = now.Add(duration)
+
+ state.sessionCounter = 0
+ state.sessionStart = now
+ }
+ state.lastSeen = now
+}
+
+func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
+ h := fnv.New64a()
+
+ h.Write([]byte(meta.WtVersion))
+ h.Write([]byte(meta.OSVersion))
+ h.Write([]byte(meta.KernelVersion))
+ h.Write([]byte(meta.Hostname))
+ h.Write([]byte(meta.SystemSerialNumber))
+ h.Write([]byte(pubip))
+
+ macs := uint64(0)
+ for _, na := range meta.NetworkAddresses {
+ for _, r := range na.Mac {
+ macs += uint64(r)
+ }
+ }
+
+ return h.Sum64() + macs
+}
diff --git a/management/server/loginfilter_test.go b/management/server/loginfilter_test.go
new file mode 100644
index 000000000..65782dd9d
--- /dev/null
+++ b/management/server/loginfilter_test.go
@@ -0,0 +1,275 @@
+package server
+
+import (
+ "hash/fnv"
+ "math"
+ "math/rand"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+)
+
+func testAdvancedCfg() *lfConfig {
+ return &lfConfig{
+ reconnThreshold: 50 * time.Millisecond,
+ baseBlockDuration: 100 * time.Millisecond,
+ reconnLimitForBan: 3,
+ metaChangeLimit: 2,
+ }
+}
+
+type LoginFilterTestSuite struct {
+ suite.Suite
+ filter *loginFilter
+}
+
+func (s *LoginFilterTestSuite) SetupTest() {
+ s.filter = newLoginFilterWithCfg(testAdvancedCfg())
+}
+
+func TestLoginFilterTestSuite(t *testing.T) {
+ suite.Run(t, new(LoginFilterTestSuite))
+}
+
+func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
+ pubKey := "PUB_KEY_A"
+ meta := uint64(1)
+
+ s.True(s.filter.allowLogin(pubKey, meta))
+
+ s.filter.addLogin(pubKey, meta)
+ s.Require().Contains(s.filter.logged, pubKey)
+ s.Equal(1, s.filter.logged[pubKey].sessionCounter)
+}
+
+func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
+ pubKey := "PUB_KEY_A"
+ meta := uint64(1)
+ limit := s.filter.cfg.reconnLimitForBan
+
+ for i := 0; i <= limit; i++ {
+ s.filter.addLogin(pubKey, meta)
+ }
+
+ s.False(s.filter.allowLogin(pubKey, meta))
+ s.Require().Contains(s.filter.logged, pubKey)
+ s.True(s.filter.logged[pubKey].isBanned)
+}
+
+func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
+ pubKey := "PUB_KEY_A"
+ meta := uint64(1)
+ limit := s.filter.cfg.reconnLimitForBan
+ baseBan := s.filter.cfg.baseBlockDuration
+
+ for i := 0; i <= limit; i++ {
+ s.filter.addLogin(pubKey, meta)
+ }
+ s.Require().Contains(s.filter.logged, pubKey)
+ s.True(s.filter.logged[pubKey].isBanned)
+ s.Equal(1, s.filter.logged[pubKey].banLevel)
+ firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
+ s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
+
+ s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
+ s.filter.logged[pubKey].isBanned = false
+
+ for i := 0; i <= limit; i++ {
+ s.filter.addLogin(pubKey, meta)
+ }
+ s.True(s.filter.logged[pubKey].isBanned)
+ s.Equal(2, s.filter.logged[pubKey].banLevel)
+ secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
+ expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
+ s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
+}
+
+func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
+ pubKey := "PUB_KEY_A"
+ meta := uint64(1)
+
+ s.filter.logged[pubKey] = &peerState{
+ isBanned: true,
+ banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
+ }
+
+ s.True(s.filter.allowLogin(pubKey, meta))
+
+ s.filter.addLogin(pubKey, meta)
+ s.Require().Contains(s.filter.logged, pubKey)
+ s.False(s.filter.logged[pubKey].isBanned)
+}
+
+func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
+ pubKey := "PUB_KEY_A"
+ meta := uint64(1)
+
+ s.filter.logged[pubKey] = &peerState{
+ currentHash: meta,
+ banLevel: 3,
+ lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
+ }
+
+ s.filter.addLogin(pubKey, meta)
+ s.Require().Contains(s.filter.logged, pubKey)
+ s.Equal(0, s.filter.logged[pubKey].banLevel)
+}
+
+func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
+ pubKey := "PUB_KEY_A"
+ limit := s.filter.cfg.metaChangeLimit
+
+ for i := range limit {
+ s.filter.addLogin(pubKey, uint64(i+1))
+ }
+
+ s.Require().Contains(s.filter.logged, pubKey)
+ s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
+
+ isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
+
+ s.False(isAllowed, "should block new meta hash after limit is reached")
+}
+
+func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
+ pubKey := "PUB_KEY_A"
+ meta1 := uint64(1)
+ meta2 := uint64(2)
+ meta3 := uint64(3)
+
+ s.filter.addLogin(pubKey, meta1)
+ s.filter.addLogin(pubKey, meta2)
+ s.Require().Contains(s.filter.logged, pubKey)
+ s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
+ s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
+
+ s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
+
+ s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
+
+ s.filter.addLogin(pubKey, meta3)
+ s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
+}
+
+func BenchmarkHashingMethods(b *testing.B) {
+ meta := nbpeer.PeerSystemMeta{
+ WtVersion: "1.25.1",
+ OSVersion: "Ubuntu 22.04.3 LTS",
+ KernelVersion: "5.15.0-76-generic",
+ Hostname: "prod-server-database-01",
+ SystemSerialNumber: "PC-1234567890",
+ NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
+ }
+ pubip := "8.8.8.8"
+
+ var resultString string
+ var resultUint uint64
+
+ b.Run("BuilderString", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ resultString = builderString(meta, pubip)
+ }
+ })
+
+ b.Run("FnvHashToString", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ resultString = fnvHashToString(meta, pubip)
+ }
+ })
+
+ b.Run("FnvHashToUint64 - used", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ resultUint = metaHash(meta, pubip)
+ }
+ })
+
+ _ = resultString
+ _ = resultUint
+}
+
+func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
+ h := fnv.New64a()
+
+ if len(meta.NetworkAddresses) != 0 {
+ for _, na := range meta.NetworkAddresses {
+ h.Write([]byte(na.Mac))
+ }
+ }
+
+ h.Write([]byte(meta.WtVersion))
+ h.Write([]byte(meta.OSVersion))
+ h.Write([]byte(meta.KernelVersion))
+ h.Write([]byte(meta.Hostname))
+ h.Write([]byte(meta.SystemSerialNumber))
+ h.Write([]byte(pubip))
+
+ return strconv.FormatUint(h.Sum64(), 16)
+}
+
+func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
+ mac := getMacAddress(meta.NetworkAddresses)
+ estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
+ len(pubip) + len(mac) + 6
+
+ var b strings.Builder
+ b.Grow(estimatedSize)
+
+ b.WriteString(meta.WtVersion)
+ b.WriteByte('|')
+ b.WriteString(meta.OSVersion)
+ b.WriteByte('|')
+ b.WriteString(meta.KernelVersion)
+ b.WriteByte('|')
+ b.WriteString(meta.Hostname)
+ b.WriteByte('|')
+ b.WriteString(meta.SystemSerialNumber)
+ b.WriteByte('|')
+ b.WriteString(pubip)
+
+ return b.String()
+}
+
+func getMacAddress(nas []nbpeer.NetworkAddress) string {
+ if len(nas) == 0 {
+ return ""
+ }
+ macs := make([]string, 0, len(nas))
+ for _, na := range nas {
+ macs = append(macs, na.Mac)
+ }
+ return strings.Join(macs, "/")
+}
+
+func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
+ filter := newLoginFilterWithCfg(testAdvancedCfg())
+ numKeys := 100000
+ pubKeys := make([]string, numKeys)
+ for i := range numKeys {
+ pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
+ }
+
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ r := rand.New(rand.NewSource(time.Now().UnixNano()))
+
+ for pb.Next() {
+ key := pubKeys[r.Intn(numKeys)]
+ meta := r.Uint64()
+
+ if filter.allowLogin(key, meta) {
+ filter.addLogin(key, meta)
+ }
+ }
+ })
+}
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index 6f9c2696f..003385eb5 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -95,6 +95,8 @@ type MockAccountManager struct {
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
+ ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
+ RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() account.ExternalCacheManager
@@ -121,8 +123,10 @@ type MockAccountManager struct {
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
- UpdateAccountPeersFunc func(ctx context.Context, accountID string)
- BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
+
+ AllowSyncFunc func(string, uint64) bool
+ UpdateAccountPeersFunc func(ctx context.Context, accountID string)
+ BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -605,6 +609,20 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string,
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
}
+func (am *MockAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
+ if am.ApproveUserFunc != nil {
+ return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented")
+}
+
+func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
+ if am.RejectUserFunc != nil {
+ return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID)
+ }
+ return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented")
+}
+
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil {
@@ -953,3 +971,10 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
}
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
}
+
+func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
+ if am.AllowSyncFunc != nil {
+ return am.AllowSyncFunc(key, hash)
+ }
+ return true
+}
diff --git a/management/server/peer.go b/management/server/peer.go
index 8af71cbd2..81f037499 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -368,10 +368,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
- if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
- return fmt.Errorf("failed to remove peer from groups: %w", err)
- }
-
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -493,6 +489,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if err != nil {
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
}
+ if user.PendingApproval {
+ return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
+ }
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index c4822aa62..31c309430 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -26,6 +26,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/server/config"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -989,19 +990,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
- minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
- minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
+ testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
}
- if msPerOp < minExpected {
- b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
- }
-
- if msPerOp > (maxExpected * 1.1) {
- b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
+ if msPerOp > maxExpected {
+ b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
@@ -1609,7 +1605,6 @@ func Test_LoginPeer(t *testing.T) {
testCases := []struct {
name string
setupKey string
- wireGuardPubKey string
expectExtraDNSLabelsMismatch bool
extraDNSLabels []string
expectLoginError bool
@@ -2388,3 +2383,186 @@ func TestBufferUpdateAccountPeers(t *testing.T) {
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}
+
+func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create account
+ account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Create user pending approval
+ pendingUser := types.NewRegularUser("pending-user")
+ pendingUser.AccountID = account.Id
+ pendingUser.Blocked = true
+ pendingUser.PendingApproval = true
+ err = manager.Store.SaveUser(context.Background(), pendingUser)
+ require.NoError(t, err)
+
+ // Try to add peer with pending approval user
+ key, err := wgtypes.GenerateKey()
+ require.NoError(t, err)
+
+ peer := &nbpeer.Peer{
+ Key: key.PublicKey().String(),
+ Name: "test-peer",
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "test-peer",
+ OS: "linux",
+ },
+ }
+
+ _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "user pending approval cannot add peers")
+}
+
+func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create account
+ account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Create regular user (not pending approval)
+ regularUser := types.NewRegularUser("regular-user")
+ regularUser.AccountID = account.Id
+ err = manager.Store.SaveUser(context.Background(), regularUser)
+ require.NoError(t, err)
+
+ // Try to add peer with regular user
+ key, err := wgtypes.GenerateKey()
+ require.NoError(t, err)
+
+ peer := &nbpeer.Peer{
+ Key: key.PublicKey().String(),
+ Name: "test-peer",
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "test-peer",
+ OS: "linux",
+ },
+ }
+
+ _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer)
+ require.NoError(t, err, "Regular user should be able to add peers")
+}
+
+func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create account
+ account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Create user pending approval
+ pendingUser := types.NewRegularUser("pending-user")
+ pendingUser.AccountID = account.Id
+ pendingUser.Blocked = true
+ pendingUser.PendingApproval = true
+ err = manager.Store.SaveUser(context.Background(), pendingUser)
+ require.NoError(t, err)
+
+ // Create a peer using AddPeer method for the pending user (simulate existing peer)
+ key, err := wgtypes.GenerateKey()
+ require.NoError(t, err)
+
+ // Set the user to not be pending initially so peer can be added
+ pendingUser.Blocked = false
+ pendingUser.PendingApproval = false
+ err = manager.Store.SaveUser(context.Background(), pendingUser)
+ require.NoError(t, err)
+
+ // Add peer using regular flow
+ newPeer := &nbpeer.Peer{
+ Key: key.PublicKey().String(),
+ Name: "test-peer",
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "test-peer",
+ OS: "linux",
+ WtVersion: "0.28.0",
+ },
+ }
+ existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer)
+ require.NoError(t, err)
+
+ // Now set the user back to pending approval after peer was created
+ pendingUser.Blocked = true
+ pendingUser.PendingApproval = true
+ err = manager.Store.SaveUser(context.Background(), pendingUser)
+ require.NoError(t, err)
+
+ // Try to login with pending approval user
+ login := types.PeerLogin{
+ WireGuardPubKey: existingPeer.Key,
+ UserID: pendingUser.Id,
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "test-peer",
+ OS: "linux",
+ },
+ }
+
+ _, _, _, err = manager.LoginPeer(context.Background(), login)
+ require.Error(t, err)
+ e, ok := status.FromError(err)
+ require.True(t, ok, "error is not a gRPC status error")
+ assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code")
+}
+
+func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create account
+ account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Create regular user (not pending approval)
+ regularUser := types.NewRegularUser("regular-user")
+ regularUser.AccountID = account.Id
+ err = manager.Store.SaveUser(context.Background(), regularUser)
+ require.NoError(t, err)
+
+ // Add peer using regular flow for the regular user
+ key, err := wgtypes.GenerateKey()
+ require.NoError(t, err)
+
+ newPeer := &nbpeer.Peer{
+ Key: key.PublicKey().String(),
+ Name: "test-peer",
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "test-peer",
+ OS: "linux",
+ WtVersion: "0.28.0",
+ },
+ }
+ existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer)
+ require.NoError(t, err)
+
+ // Try to login with regular user
+ login := types.PeerLogin{
+ WireGuardPubKey: existingPeer.Key,
+ UserID: regularUser.Id,
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "test-peer",
+ OS: "linux",
+ },
+ }
+
+ _, _, _, err = manager.LoginPeer(context.Background(), login)
+ require.NoError(t, err, "Regular user should be able to login peers")
+}
diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go
index 50e36a880..cb135f4ac 100644
--- a/management/server/peers/manager.go
+++ b/management/server/peers/manager.go
@@ -18,6 +18,7 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
+ GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
}
type managerImpl struct {
@@ -61,3 +62,7 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}
+
+func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
+ return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
+}
diff --git a/management/server/peers/manager_mock.go b/management/server/peers/manager_mock.go
index b247a1752..994f8346b 100644
--- a/management/server/peers/manager_mock.go
+++ b/management/server/peers/manager_mock.go
@@ -79,3 +79,18 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
+
+// GetPeersByGroupIDs mocks base method.
+func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs)
+ ret0, _ := ret[0].([]*peer.Peer)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs.
+func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
+}
diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go
index 0ab244243..891fa59bb 100644
--- a/management/server/permissions/manager.go
+++ b/management/server/permissions/manager.go
@@ -54,10 +54,14 @@ func (m *managerImpl) ValidateUserPermissions(
return false, status.NewUserNotFoundError(userID)
}
- if user.IsBlocked() {
+ if user.IsBlocked() && !user.PendingApproval {
return false, status.NewUserBlockedError()
}
+ if user.IsBlocked() && user.PendingApproval {
+ return false, status.NewUserPendingApprovalError()
+ }
+
if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return false, err
}
diff --git a/management/server/policy.go b/management/server/policy.go
index 312fd53b2..3adee6397 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -167,10 +167,22 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" {
- _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
+ existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil {
return err
}
+
+ // TODO: Refactor to support multiple rules per policy
+ existingRuleIDs := make(map[string]bool)
+ for _, rule := range existingPolicy.Rules {
+ existingRuleIDs[rule.ID] = true
+ }
+
+ for _, rule := range policy.Rules {
+ if rule.ID != "" && !existingRuleIDs[rule.ID] {
+ return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
+ }
+ }
} else {
policy.ID = xid.New().String()
policy.AccountID = accountID
diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go
index 6ef93f0d1..027938320 100644
--- a/management/server/store/sql_store.go
+++ b/management/server/store/sql_store.go
@@ -914,7 +914,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
var account types.Account
- result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account)
+ result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account)
if result.Error != nil {
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -1399,7 +1399,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
PeerID: peerID,
}
- err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
+ err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
@@ -1414,7 +1414,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
// RemovePeerFromGroup removes a peer from a group
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
- err := s.db.WithContext(ctx).
+ err := s.db.
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
if err != nil {
@@ -1427,7 +1427,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
- err := s.db.WithContext(ctx).
+ err := s.db.
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
if err != nil {
@@ -2015,7 +2015,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
}
func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error {
- return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ return s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
return fmt.Errorf("delete policy rules: %w", err)
}
@@ -2706,7 +2706,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
}
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
- tx := s.db.WithContext(ctx)
+ tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -2847,3 +2847,22 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
}
return nil
}
+
+func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) {
+ if len(groupIDs) == 0 {
+ return []*nbpeer.Peer{}, nil
+ }
+
+ var peers []*nbpeer.Peer
+ peerIDsSubquery := s.db.Model(&types.GroupPeer{}).
+ Select("DISTINCT peer_id").
+ Where("account_id = ? AND group_id IN ?", accountID, groupIDs)
+
+ result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get peers by group IDs")
+ }
+
+ return peers, nil
+}
diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go
index 935b0a595..d40c4664c 100644
--- a/management/server/store/sql_store_test.go
+++ b/management/server/store/sql_store_test.go
@@ -3607,3 +3607,113 @@ func intToIPv4(n uint32) net.IP {
binary.BigEndian.PutUint32(ip, n)
return ip
}
+
+func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ group1ID := "test-group-1"
+ group2ID := "test-group-2"
+ emptyGroupID := "empty-group"
+
+ peer1 := "cfefqs706sqkneg59g4g"
+ peer2 := "cfeg6sf06sqkneg59g50"
+
+ tests := []struct {
+ name string
+ groupIDs []string
+ expectedPeers []string
+ expectedCount int
+ }{
+ {
+ name: "retrieve peers from single group with multiple peers",
+ groupIDs: []string{group1ID},
+ expectedPeers: []string{peer1, peer2},
+ expectedCount: 2,
+ },
+ {
+ name: "retrieve peers from single group with one peer",
+ groupIDs: []string{group2ID},
+ expectedPeers: []string{peer1},
+ expectedCount: 1,
+ },
+ {
+ name: "retrieve peers from multiple groups (with overlap)",
+ groupIDs: []string{group1ID, group2ID},
+ expectedPeers: []string{peer1, peer2}, // should deduplicate
+ expectedCount: 2,
+ },
+ {
+ name: "retrieve peers from existing 'All' group",
+ groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data
+ expectedPeers: []string{peer1, peer2},
+ expectedCount: 2,
+ },
+ {
+ name: "retrieve peers from empty group",
+ groupIDs: []string{emptyGroupID},
+ expectedPeers: []string{},
+ expectedCount: 0,
+ },
+ {
+ name: "retrieve peers from non-existing group",
+ groupIDs: []string{"non-existing-group"},
+ expectedPeers: []string{},
+ expectedCount: 0,
+ },
+ {
+ name: "empty group IDs list",
+ groupIDs: []string{},
+ expectedPeers: []string{},
+ expectedCount: 0,
+ },
+ {
+ name: "mix of existing and non-existing groups",
+ groupIDs: []string{group1ID, "non-existing-group"},
+ expectedPeers: []string{peer1, peer2},
+ expectedCount: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ ctx := context.Background()
+
+ groups := []*types.Group{
+ {
+ ID: group1ID,
+ AccountID: accountID,
+ },
+ {
+ ID: group2ID,
+ AccountID: accountID,
+ },
+ }
+ require.NoError(t, store.CreateGroups(ctx, accountID, groups))
+
+ require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID))
+ require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID))
+ require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID))
+
+ peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs)
+ require.NoError(t, err)
+ require.Len(t, peers, tt.expectedCount)
+
+ if tt.expectedCount > 0 {
+ actualPeerIDs := make([]string, len(peers))
+ for i, peer := range peers {
+ actualPeerIDs[i] = peer.ID
+ }
+ assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs)
+
+ // Verify all returned peers belong to the correct account
+ for _, peer := range peers {
+ assert.Equal(t, accountID, peer.AccountID)
+ }
+ }
+ })
+ }
+}
diff --git a/management/server/store/store.go b/management/server/store/store.go
index 545549410..3c9d896b0 100644
--- a/management/server/store/store.go
+++ b/management/server/store/store.go
@@ -136,6 +136,7 @@ type Store interface {
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
+ GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go
index ac6ff2ea8..d4301802f 100644
--- a/management/server/telemetry/grpc_metrics.go
+++ b/management/server/telemetry/grpc_metrics.go
@@ -4,20 +4,28 @@ import (
"context"
"time"
+ "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
+const AccountIDLabel = "account_id"
+const HighLatencyThreshold = time.Second * 7
+
// GRPCMetrics are gRPC server metrics
type GRPCMetrics struct {
- meter metric.Meter
- syncRequestsCounter metric.Int64Counter
- loginRequestsCounter metric.Int64Counter
- getKeyRequestsCounter metric.Int64Counter
- activeStreamsGauge metric.Int64ObservableGauge
- syncRequestDuration metric.Int64Histogram
- loginRequestDuration metric.Int64Histogram
- channelQueueLength metric.Int64Histogram
- ctx context.Context
+ meter metric.Meter
+ syncRequestsCounter metric.Int64Counter
+ syncRequestsBlockedCounter metric.Int64Counter
+ syncRequestHighLatencyCounter metric.Int64Counter
+ loginRequestsCounter metric.Int64Counter
+ loginRequestsBlockedCounter metric.Int64Counter
+ loginRequestHighLatencyCounter metric.Int64Counter
+ getKeyRequestsCounter metric.Int64Counter
+ activeStreamsGauge metric.Int64ObservableGauge
+ syncRequestDuration metric.Int64Histogram
+ loginRequestDuration metric.Int64Histogram
+ channelQueueLength metric.Int64Histogram
+ ctx context.Context
}
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
@@ -30,6 +38,22 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
+ syncRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.sync.request.blocked.counter",
+ metric.WithUnit("1"),
+ metric.WithDescription("Number of sync gRPC requests from blocked peers"),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
+ metric.WithUnit("1"),
+ metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
+ )
+ if err != nil {
+ return nil, err
+ }
+
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -38,6 +62,22 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
+ loginRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.login.request.blocked.counter",
+ metric.WithUnit("1"),
+ metric.WithDescription("Number of login gRPC requests from blocked peers"),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter",
+ metric.WithUnit("1"),
+ metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"),
+ )
+ if err != nil {
+ return nil, err
+ }
+
getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of key gRPC requests from the peers to get the server's public WireGuard key"),
@@ -83,15 +123,19 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
}
return &GRPCMetrics{
- meter: meter,
- syncRequestsCounter: syncRequestsCounter,
- loginRequestsCounter: loginRequestsCounter,
- getKeyRequestsCounter: getKeyRequestsCounter,
- activeStreamsGauge: activeStreamsGauge,
- syncRequestDuration: syncRequestDuration,
- loginRequestDuration: loginRequestDuration,
- channelQueueLength: channelQueue,
- ctx: ctx,
+ meter: meter,
+ syncRequestsCounter: syncRequestsCounter,
+ syncRequestsBlockedCounter: syncRequestsBlockedCounter,
+ syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
+ loginRequestsCounter: loginRequestsCounter,
+ loginRequestsBlockedCounter: loginRequestsBlockedCounter,
+ loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
+ getKeyRequestsCounter: getKeyRequestsCounter,
+ activeStreamsGauge: activeStreamsGauge,
+ syncRequestDuration: syncRequestDuration,
+ loginRequestDuration: loginRequestDuration,
+ channelQueueLength: channelQueue,
+ ctx: ctx,
}, err
}
@@ -100,6 +144,11 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequest() {
grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1)
}
+// CountSyncRequestBlocked counts the number of gRPC sync requests from blocked peers
+func (grpcMetrics *GRPCMetrics) CountSyncRequestBlocked() {
+ grpcMetrics.syncRequestsBlockedCounter.Add(grpcMetrics.ctx, 1)
+}
+
// CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() {
grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1)
@@ -110,14 +159,25 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() {
grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1)
}
+// CountLoginRequestBlocked counts the number of gRPC login requests from blocked peers
+func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
+ grpcMetrics.loginRequestsBlockedCounter.Add(grpcMetrics.ctx, 1)
+}
+
// CountLoginRequestDuration counts the duration of the login gRPC requests
-func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration) {
+func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
+ if duration > HighLatencyThreshold {
+ grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
+ }
}
// CountSyncRequestDuration counts the duration of the sync gRPC requests
-func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration) {
+func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
+ if duration > HighLatencyThreshold {
+ grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
+ }
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
diff --git a/management/server/types/account.go b/management/server/types/account.go
index 9ac2568a0..ca075b9f6 100644
--- a/management/server/types/account.go
+++ b/management/server/types/account.go
@@ -302,7 +302,11 @@ func (a *Account) GetPeerNetworkMap(
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
- zones = append(zones, peersCustomZone)
+ records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
+ zones = append(zones, nbdns.CustomZone{
+ Domain: peersCustomZone.Domain,
+ Records: records,
+ })
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
@@ -1651,3 +1655,24 @@ func peerSupportsPortRanges(peerVer string) bool {
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
return err == nil && meetMinVer
}
+
+// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
+func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
+ filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
+ peerIPs := make(map[string]struct{})
+
+ // Add peer's own IP to include its own DNS records
+ peerIPs[peer.IP.String()] = struct{}{}
+
+ for _, peerToConnect := range peersToConnect {
+ peerIPs[peerToConnect.IP.String()] = struct{}{}
+ }
+
+ for _, record := range customZone.Records {
+ if _, exists := peerIPs[record.RData]; exists {
+ filteredRecords = append(filteredRecords, record)
+ }
+ }
+
+ return filteredRecords
+}
diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go
index f8ab1d627..cd221b590 100644
--- a/management/server/types/account_test.go
+++ b/management/server/types/account_test.go
@@ -2,14 +2,17 @@ package types
import (
"context"
+ "fmt"
"net"
"net/netip"
"slices"
"testing"
+ "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -835,3 +838,109 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
}
+
+func Test_FilterZoneRecordsForPeers(t *testing.T) {
+ tests := []struct {
+ name string
+ peer *nbpeer.Peer
+ customZone nbdns.CustomZone
+ peersToConnect []*nbpeer.Peer
+ expectedRecords []nbdns.SimpleRecord
+ }{
+ {
+ name: "empty peers to connect",
+ customZone: nbdns.CustomZone{
+ Domain: "netbird.cloud.",
+ Records: []nbdns.SimpleRecord{
+ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
+ },
+ },
+ peersToConnect: []*nbpeer.Peer{},
+ peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
+ expectedRecords: []nbdns.SimpleRecord{
+ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
+ },
+ },
+ {
+ name: "multiple peers multiple records match",
+ customZone: nbdns.CustomZone{
+ Domain: "netbird.cloud.",
+ Records: func() []nbdns.SimpleRecord {
+ var records []nbdns.SimpleRecord
+ for i := 1; i <= 100; i++ {
+ records = append(records, nbdns.SimpleRecord{
+ Name: fmt.Sprintf("peer%d.netbird.cloud", i),
+ Type: int(dns.TypeA),
+ Class: nbdns.DefaultClass,
+ TTL: 300,
+ RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
+ })
+ }
+ return records
+ }(),
+ },
+ peersToConnect: func() []*nbpeer.Peer {
+ var peers []*nbpeer.Peer
+ for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
+ peers = append(peers, &nbpeer.Peer{
+ ID: fmt.Sprintf("peer%d", i),
+ IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)),
+ })
+ }
+ return peers
+ }(),
+ peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
+ expectedRecords: func() []nbdns.SimpleRecord {
+ var records []nbdns.SimpleRecord
+ for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
+ records = append(records, nbdns.SimpleRecord{
+ Name: fmt.Sprintf("peer%d.netbird.cloud", i),
+ Type: int(dns.TypeA),
+ Class: nbdns.DefaultClass,
+ TTL: 300,
+ RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
+ })
+ }
+ return records
+ }(),
+ },
+ {
+ name: "peers with multiple DNS labels",
+ customZone: nbdns.CustomZone{
+ Domain: "netbird.cloud.",
+ Records: []nbdns.SimpleRecord{
+ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
+ {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
+ {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
+ {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
+ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
+ },
+ },
+ peersToConnect: []*nbpeer.Peer{
+ {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
+ {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
+ },
+ peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
+ expectedRecords: []nbdns.SimpleRecord{
+ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
+ {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
+ {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
+ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
+ assert.Equal(t, len(tt.expectedRecords), len(result))
+ assert.ElementsMatch(t, tt.expectedRecords, result)
+ })
+ }
+}
diff --git a/management/server/types/network.go b/management/server/types/network.go
index f072a4294..ffc019565 100644
--- a/management/server/types/network.go
+++ b/management/server/types/network.go
@@ -12,11 +12,11 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
- "github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
+ "github.com/netbirdio/netbird/shared/management/proto"
+ "github.com/netbirdio/netbird/shared/management/status"
)
const (
diff --git a/management/server/types/settings.go b/management/server/types/settings.go
index 56c33da3b..b4afb2f5e 100644
--- a/management/server/types/settings.go
+++ b/management/server/types/settings.go
@@ -83,6 +83,9 @@ type ExtraSettings struct {
// PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator
PeerApprovalEnabled bool
+ // UserApprovalRequired enables or disables the need for users joining via domain matching to be approved by an administrator
+ UserApprovalRequired bool
+
// IntegratedValidator is the string enum for the integrated validator type
IntegratedValidator string
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
@@ -99,6 +102,7 @@ type ExtraSettings struct {
func (e *ExtraSettings) Copy() *ExtraSettings {
return &ExtraSettings{
PeerApprovalEnabled: e.PeerApprovalEnabled,
+ UserApprovalRequired: e.UserApprovalRequired,
IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
IntegratedValidator: e.IntegratedValidator,
FlowEnabled: e.FlowEnabled,
diff --git a/management/server/types/user.go b/management/server/types/user.go
index 783fe14da..beb3586df 100644
--- a/management/server/types/user.go
+++ b/management/server/types/user.go
@@ -64,6 +64,7 @@ type UserInfo struct {
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
+ PendingApproval bool `json:"pending_approval"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
}
@@ -84,6 +85,8 @@ type User struct {
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
+ // PendingApproval indicates whether the user requires approval before being activated
+ PendingApproval bool
// LastLogin is the last time the user logged in to IdP
LastLogin *time.Time
// CreatedAt records the time the user was created
@@ -141,16 +144,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
if userData == nil {
return &UserInfo{
- ID: u.Id,
- Email: "",
- Name: u.ServiceUserName,
- Role: string(u.Role),
- AutoGroups: u.AutoGroups,
- Status: string(UserStatusActive),
- IsServiceUser: u.IsServiceUser,
- IsBlocked: u.Blocked,
- LastLogin: u.GetLastLogin(),
- Issued: u.Issued,
+ ID: u.Id,
+ Email: "",
+ Name: u.ServiceUserName,
+ Role: string(u.Role),
+ AutoGroups: u.AutoGroups,
+ Status: string(UserStatusActive),
+ IsServiceUser: u.IsServiceUser,
+ IsBlocked: u.Blocked,
+ LastLogin: u.GetLastLogin(),
+ Issued: u.Issued,
+ PendingApproval: u.PendingApproval,
}, nil
}
if userData.ID != u.Id {
@@ -163,16 +167,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
}
return &UserInfo{
- ID: u.Id,
- Email: userData.Email,
- Name: userData.Name,
- Role: string(u.Role),
- AutoGroups: autoGroups,
- Status: string(userStatus),
- IsServiceUser: u.IsServiceUser,
- IsBlocked: u.Blocked,
- LastLogin: u.GetLastLogin(),
- Issued: u.Issued,
+ ID: u.Id,
+ Email: userData.Email,
+ Name: userData.Name,
+ Role: string(u.Role),
+ AutoGroups: autoGroups,
+ Status: string(userStatus),
+ IsServiceUser: u.IsServiceUser,
+ IsBlocked: u.Blocked,
+ LastLogin: u.GetLastLogin(),
+ Issued: u.Issued,
+ PendingApproval: u.PendingApproval,
}, nil
}
@@ -194,6 +199,7 @@ func (u *User) Copy() *User {
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
+ PendingApproval: u.PendingApproval,
LastLogin: u.LastLogin,
CreatedAt: u.CreatedAt,
Issued: u.Issued,
diff --git a/management/server/user.go b/management/server/user.go
index 4596ee95b..d40d33c6a 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -519,33 +519,46 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
initiatorUser = result
}
- err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
- for _, update := range updates {
- if update == nil {
- return status.Errorf(status.InvalidArgument, "provided user update is nil")
- }
+ var globalErr error
+ for _, update := range updates {
+ if update == nil {
+ return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
+ }
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
)
if err != nil {
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
}
- usersToSave = append(usersToSave, updatedUser)
- addUserEvents = append(addUserEvents, userEvents...)
- peersToExpire = append(peersToExpire, userPeersToExpire...)
if userHadPeers {
updateAccountPeers = true
}
+
+ err = transaction.SaveUser(ctx, updatedUser)
+ if err != nil {
+ return fmt.Errorf("failed to save updated user %s: %w", update.Id, err)
+ }
+
+ usersToSave = append(usersToSave, updatedUser)
+ addUserEvents = append(addUserEvents, userEvents...)
+ peersToExpire = append(peersToExpire, userPeersToExpire...)
+
+ return nil
+ })
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to save user %s: %s", update.Id, err)
+ if len(updates) == 1 {
+ return nil, err
+ }
+ globalErr = errors.Join(globalErr, err)
+ // continue when updating multiple users
}
- return transaction.SaveUsers(ctx, usersToSave)
- })
- if err != nil {
- return nil, err
}
- var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates))
+ var updatedUsersInfo = make([]*types.UserInfo, 0, len(usersToSave))
userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID)
if err != nil {
@@ -578,7 +591,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
am.UpdateAccountPeers(ctx, accountID)
}
- return updatedUsersInfo, nil
+ return updatedUsersInfo, globalErr
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
@@ -643,7 +656,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
}
transferredOwnerRole = result
- userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id)
+ userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, updatedUser.AccountID, update.Id)
if err != nil {
return false, nil, nil, nil, err
}
@@ -929,6 +942,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
+ if peer.UserID == "" {
+ // we do not want to expire peers that are added via setup key
+ continue
+ }
+
if peer.Status.LoginExpired {
continue
}
@@ -947,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
+ log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -1194,3 +1213,77 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return userWithPermissions, nil
}
+
+// ApproveUser approves a user that is pending approval
+func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
+ if err != nil {
+ return nil, status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return nil, status.NewPermissionDeniedError()
+ }
+
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
+ if err != nil {
+ return nil, err
+ }
+
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotFoundError(targetUserID)
+ }
+
+ if !user.PendingApproval {
+ return nil, status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID)
+ }
+
+ user.Blocked = false
+ user.PendingApproval = false
+
+ err = am.Store.SaveUser(ctx, user)
+ if err != nil {
+ return nil, err
+ }
+
+ am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserApproved, nil)
+
+ userInfo, err := am.getUserInfo(ctx, user, accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ return userInfo, nil
+}
+
+// RejectUser rejects a user that is pending approval by deleting them
+func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
+ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
+ if err != nil {
+ return status.NewPermissionValidationError(err)
+ }
+ if !allowed {
+ return status.NewPermissionDeniedError()
+ }
+
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
+ if err != nil {
+ return err
+ }
+
+ if user.AccountID != accountID {
+ return status.NewUserNotFoundError(targetUserID)
+ }
+
+ if !user.PendingApproval {
+ return status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID)
+ }
+
+ err = am.DeleteUser(ctx, accountID, initiatorUserID, targetUserID)
+ if err != nil {
+ return err
+ }
+
+ am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserRejected, nil)
+
+ return nil
+}
diff --git a/management/server/user_test.go b/management/server/user_test.go
index 8ab0c1565..9638559f9 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -1746,3 +1746,117 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
return permissions
}
+
+func TestApproveUser(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create account with admin and pending approval user
+ account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Create admin user
+ adminUser := types.NewAdminUser("admin-user")
+ adminUser.AccountID = account.Id
+ err = manager.Store.SaveUser(context.Background(), adminUser)
+ require.NoError(t, err)
+
+ // Create user pending approval
+ pendingUser := types.NewRegularUser("pending-user")
+ pendingUser.AccountID = account.Id
+ pendingUser.Blocked = true
+ pendingUser.PendingApproval = true
+ err = manager.Store.SaveUser(context.Background(), pendingUser)
+ require.NoError(t, err)
+
+ // Test successful approval
+ approvedUser, err := manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
+ require.NoError(t, err)
+ assert.False(t, approvedUser.IsBlocked)
+ assert.False(t, approvedUser.PendingApproval)
+
+ // Verify user is updated in store
+ updatedUser, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id)
+ require.NoError(t, err)
+ assert.False(t, updatedUser.Blocked)
+ assert.False(t, updatedUser.PendingApproval)
+
+ // Test approval of non-pending user should fail
+ _, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "not pending approval")
+
+ // Test approval by non-admin should fail
+ regularUser := types.NewRegularUser("regular-user")
+ regularUser.AccountID = account.Id
+ err = manager.Store.SaveUser(context.Background(), regularUser)
+ require.NoError(t, err)
+
+ pendingUser2 := types.NewRegularUser("pending-user-2")
+ pendingUser2.AccountID = account.Id
+ pendingUser2.Blocked = true
+ pendingUser2.PendingApproval = true
+ err = manager.Store.SaveUser(context.Background(), pendingUser2)
+ require.NoError(t, err)
+
+ _, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
+ require.Error(t, err)
+}
+
+func TestRejectUser(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create account with admin and pending approval user
+ account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
+ err = manager.Store.SaveAccount(context.Background(), account)
+ require.NoError(t, err)
+
+ // Create admin user
+ adminUser := types.NewAdminUser("admin-user")
+ adminUser.AccountID = account.Id
+ err = manager.Store.SaveUser(context.Background(), adminUser)
+ require.NoError(t, err)
+
+ // Create user pending approval
+ pendingUser := types.NewRegularUser("pending-user")
+ pendingUser.AccountID = account.Id
+ pendingUser.Blocked = true
+ pendingUser.PendingApproval = true
+ err = manager.Store.SaveUser(context.Background(), pendingUser)
+ require.NoError(t, err)
+
+ // Test successful rejection
+ err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
+ require.NoError(t, err)
+
+ // Verify user is deleted from store
+ _, err = manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id)
+ require.Error(t, err)
+
+ // Test rejection of non-pending user should fail
+ regularUser := types.NewRegularUser("regular-user")
+ regularUser.AccountID = account.Id
+ err = manager.Store.SaveUser(context.Background(), regularUser)
+ require.NoError(t, err)
+
+ err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "not pending approval")
+
+ // Test rejection by non-admin should fail
+ pendingUser2 := types.NewRegularUser("pending-user-2")
+ pendingUser2.AccountID = account.Id
+ pendingUser2.Blocked = true
+ pendingUser2.PendingApproval = true
+ err = manager.Store.SaveUser(context.Background(), pendingUser2)
+ require.NoError(t, err)
+
+ err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
+ require.Error(t, err)
+}
diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go
index 332127660..12219e29b 100644
--- a/relay/server/listener/ws/listener.go
+++ b/relay/server/listener/ws/listener.go
@@ -73,7 +73,12 @@ func (l *Listener) Shutdown(ctx context.Context) error {
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
connRemoteAddr := remoteAddr(r)
- wsConn, err := websocket.Accept(w, r, nil)
+
+ acceptOptions := &websocket.AcceptOptions{
+ OriginPatterns: []string{"*"},
+ }
+
+ wsConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
return
diff --git a/release_files/install.sh b/release_files/install.sh
index 856d332cb..5d5349ec4 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -130,36 +130,6 @@ repo_gpgcheck=1
EOF
}
-install_aur_package() {
- INSTALL_PKGS="git base-devel go"
- REMOVE_PKGS=""
-
- # Check if dependencies are installed
- for PKG in $INSTALL_PKGS; do
- if ! pacman -Q "$PKG" > /dev/null 2>&1; then
- # Install missing package(s)
- ${SUDO} pacman -S "$PKG" --noconfirm
-
- # Add installed package for clean up later
- REMOVE_PKGS="$REMOVE_PKGS $PKG"
- fi
- done
-
- # Build package from AUR
- cd /tmp && git clone https://aur.archlinux.org/netbird.git
- cd netbird && makepkg -sri --noconfirm
-
- if ! $SKIP_UI_APP; then
- cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git
- cd netbird-ui && makepkg -sri --noconfirm
- fi
-
- if [ -n "$REMOVE_PKGS" ]; then
- # Clean up the installed packages
- ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
- fi
-}
-
prepare_tun_module() {
# Create the necessary file structure for /dev/net/tun
if [ ! -c /dev/net/tun ]; then
@@ -276,12 +246,9 @@ install_netbird() {
if ! $SKIP_UI_APP; then
${SUDO} rpm-ostree -y install netbird-ui
fi
- ;;
- pacman)
- ${SUDO} pacman -Syy
- install_aur_package
- # in-line with the docs at https://wiki.archlinux.org/title/Netbird
- ${SUDO} systemctl enable --now netbird@main.service
+ # ensure the service is started after install
+ ${SUDO} netbird service install || true
+ ${SUDO} netbird service start || true
;;
pkg)
# Check if the package is already installed
@@ -458,11 +425,7 @@ if type uname >/dev/null 2>&1; then
elif [ -x "$(command -v yum)" ]; then
PACKAGE_MANAGER="yum"
echo "The installation will be performed using yum package manager"
- elif [ -x "$(command -v pacman)" ]; then
- PACKAGE_MANAGER="pacman"
- echo "The installation will be performed using pacman package manager"
fi
-
else
echo "Unable to determine OS type from /etc/os-release"
exit 1
diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go
index 3037b44bb..becc10ded 100644
--- a/shared/management/client/client_test.go
+++ b/shared/management/client/client_test.go
@@ -9,34 +9,30 @@ import (
"time"
"github.com/golang/mock/gomock"
- "github.com/stretchr/testify/require"
-
- "github.com/netbirdio/netbird/client/system"
- "github.com/netbirdio/netbird/management/internals/server/config"
- "github.com/netbirdio/netbird/management/server/activity"
- "github.com/netbirdio/netbird/management/server/groups"
- "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
- "github.com/netbirdio/netbird/management/server/permissions"
- "github.com/netbirdio/netbird/management/server/settings"
- "github.com/netbirdio/netbird/management/server/store"
- "github.com/netbirdio/netbird/management/server/telemetry"
- "github.com/netbirdio/netbird/management/server/types"
-
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
-
- "github.com/netbirdio/management-integrations/integrations"
-
- "github.com/netbirdio/netbird/encryption"
- mgmt "github.com/netbirdio/netbird/management/server"
- "github.com/netbirdio/netbird/management/server/mock_server"
- mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
-
+ "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
+ "github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/client/system"
+ "github.com/netbirdio/netbird/encryption"
+ "github.com/netbirdio/netbird/management/internals/server/config"
+ mgmt "github.com/netbirdio/netbird/management/server"
+ "github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/groups"
+ "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
+ "github.com/netbirdio/netbird/management/server/mock_server"
+ "github.com/netbirdio/netbird/management/server/peers"
+ "github.com/netbirdio/netbird/management/server/permissions"
+ "github.com/netbirdio/netbird/management/server/settings"
+ "github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/management/server/telemetry"
+ "github.com/netbirdio/netbird/management/server/types"
+ mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -72,13 +68,31 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
- ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
+
+ ctrl := gomock.NewController(t)
+ t.Cleanup(ctrl.Finish)
+
+ permissionsManagerMock := permissions.NewMockManager(ctrl)
+ permissionsManagerMock.
+ EXPECT().
+ ValidateUserPermissions(
+ gomock.Any(),
+ gomock.Any(),
+ gomock.Any(),
+ gomock.Any(),
+ gomock.Any(),
+ ).
+ Return(true, nil).
+ AnyTimes()
+
+ peersManger := peers.NewManager(store, permissionsManagerMock)
+ settingsManagerMock := settings.NewMockManager(ctrl)
+
+ ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
- ctrl := gomock.NewController(t)
- t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
@@ -95,19 +109,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(&types.ExtraSettings{}, nil).
AnyTimes()
- permissionsManagerMock := permissions.NewMockManager(ctrl)
- permissionsManagerMock.
- EXPECT().
- ValidateUserPermissions(
- gomock.Any(),
- gomock.Any(),
- gomock.Any(),
- gomock.Any(),
- gomock.Any(),
- ).
- Return(true, nil).
- AnyTimes()
-
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go
index dc26253e9..f30e965be 100644
--- a/shared/management/client/grpc.go
+++ b/shared/management/client/grpc.go
@@ -17,11 +17,11 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
+ nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
- nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
const ConnectTimeout = 10 * time.Second
@@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
operation := func() error {
var err error
- conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
+ conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
diff --git a/shared/management/client/rest/client_test.go b/shared/management/client/rest/client_test.go
index 56c859652..54a0290d0 100644
--- a/shared/management/client/rest/client_test.go
+++ b/shared/management/client/rest/client_test.go
@@ -8,8 +8,8 @@ import (
"net/http/httptest"
"testing"
+ "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/client/rest"
- "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
func withMockClient(callback func(*rest.Client, *http.ServeMux)) {
@@ -26,7 +26,7 @@ func ptr[T any, PT *T](x T) PT {
func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) {
t.Helper()
- handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false)
+ handler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false)
server := httptest.NewServer(handler)
defer server.Close()
c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml
index cf4b6d625..9a531b2ff 100644
--- a/shared/management/http/api/openapi.yml
+++ b/shared/management/http/api/openapi.yml
@@ -158,6 +158,10 @@ components:
description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
type: boolean
example: true
+ user_approval_required:
+ description: Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin.
+ type: boolean
+ example: false
network_traffic_logs_enabled:
description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
type: boolean
@@ -174,6 +178,7 @@ components:
example: true
required:
- peer_approval_enabled
+ - user_approval_required
- network_traffic_logs_enabled
- network_traffic_logs_groups
- network_traffic_packet_counter_enabled
@@ -235,6 +240,10 @@ components:
description: Is true if this user is blocked. Blocked users can't use the system
type: boolean
example: false
+ pending_approval:
+ description: Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled.
+ type: boolean
+ example: false
issued:
description: How user was issued by API or Integration
type: string
@@ -249,6 +258,7 @@ components:
- auto_groups
- status
- is_blocked
+ - pending_approval
UserPermissions:
type: object
properties:
@@ -2544,6 +2554,63 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
+ /api/users/{userId}/approve:
+ post:
+ summary: Approve user
+ description: Approve a user that is pending approval
+ tags: [ Users ]
+ security:
+ - BearerAuth: [ ]
+ - TokenAuth: [ ]
+ parameters:
+ - in: path
+ name: userId
+ required: true
+ schema:
+ type: string
+ description: The unique identifier of a user
+ responses:
+ '200':
+ description: Returns the approved user
+ content:
+ application/json:
+ schema:
+ "$ref": "#/components/schemas/User"
+ '400':
+ "$ref": "#/components/responses/bad_request"
+ '401':
+ "$ref": "#/components/responses/requires_authentication"
+ '403':
+ "$ref": "#/components/responses/forbidden"
+ '500':
+ "$ref": "#/components/responses/internal_error"
+ /api/users/{userId}/reject:
+ delete:
+ summary: Reject user
+ description: Reject a user that is pending approval by removing them from the account
+ tags: [ Users ]
+ security:
+ - BearerAuth: [ ]
+ - TokenAuth: [ ]
+ parameters:
+ - in: path
+ name: userId
+ required: true
+ schema:
+ type: string
+ description: The unique identifier of a user
+ responses:
+ '200':
+ description: User rejected successfully
+ content: {}
+ '400':
+ "$ref": "#/components/responses/bad_request"
+ '401':
+ "$ref": "#/components/responses/requires_authentication"
+ '403':
+ "$ref": "#/components/responses/forbidden"
+ '500':
+ "$ref": "#/components/responses/internal_error"
/api/users/current:
get:
summary: Retrieve current user
diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go
index cffc9e735..28b89633c 100644
--- a/shared/management/http/api/types.gen.go
+++ b/shared/management/http/api/types.gen.go
@@ -268,6 +268,9 @@ type AccountExtraSettings struct {
// PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
PeerApprovalEnabled bool `json:"peer_approval_enabled"`
+
+ // UserApprovalRequired Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin.
+ UserApprovalRequired bool `json:"user_approval_required"`
}
// AccountOnboarding defines model for AccountOnboarding.
@@ -1015,8 +1018,6 @@ type OSVersionCheck struct {
// Peer defines model for Peer.
type Peer struct {
- // CreatedAt Peer creation date (UTC)
- CreatedAt time.Time `json:"created_at"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
@@ -1032,6 +1033,9 @@ type Peer struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
+ // CreatedAt Peer creation date (UTC)
+ CreatedAt time.Time `json:"created_at"`
+
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1098,8 +1102,6 @@ type Peer struct {
// PeerBatch defines model for PeerBatch.
type PeerBatch struct {
- // CreatedAt Peer creation date (UTC)
- CreatedAt time.Time `json:"created_at"`
// AccessiblePeersCount Number of accessible peers
AccessiblePeersCount int `json:"accessible_peers_count"`
@@ -1118,6 +1120,9 @@ type PeerBatch struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
+ // CreatedAt Peer creation date (UTC)
+ CreatedAt time.Time `json:"created_at"`
+
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1774,8 +1779,11 @@ type User struct {
LastLogin *time.Time `json:"last_login,omitempty"`
// Name User's name from idp provider
- Name string `json:"name"`
- Permissions *UserPermissions `json:"permissions,omitempty"`
+ Name string `json:"name"`
+
+ // PendingApproval Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled.
+ PendingApproval bool `json:"pending_approval"`
+ Permissions *UserPermissions `json:"permissions,omitempty"`
// Role User's NetBird account role
Role string `json:"role"`
diff --git a/shared/management/status/error.go b/shared/management/status/error.go
index 7660174d6..1e914babb 100644
--- a/shared/management/status/error.go
+++ b/shared/management/status/error.go
@@ -42,7 +42,10 @@ const (
// Type is a type of the Error
type Type int32
-var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found")
+var (
+ ErrExtraSettingsNotFound = errors.New("extra settings not found")
+ ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in")
+)
// Error is an internal error
type Error struct {
@@ -110,6 +113,11 @@ func NewUserBlockedError() error {
return Errorf(PermissionDenied, "user is blocked")
}
+// NewUserPendingApprovalError creates a new Error with PermissionDenied type for a blocked user pending approval
+func NewUserPendingApprovalError() error {
+ return Errorf(PermissionDenied, "user is pending approval")
+}
+
// NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer
func NewPeerNotRegisteredError() error {
return Errorf(Unauthenticated, "peer is not registered")
diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go
index b496f6a9b..967e18d79 100644
--- a/shared/relay/client/dialer/quic/quic.go
+++ b/shared/relay/client/dialer/quic/quic.go
@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/shared/relay/tls"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type Dialer struct {
diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go
index 109651f5d..ef6bd6b3c 100644
--- a/shared/relay/client/dialer/ws/ws.go
+++ b/shared/relay/client/dialer/ws/ws.go
@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/util/embeddedroots"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
type Dialer struct {
diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go
index a40343fb1..6220e7f6b 100644
--- a/shared/relay/client/manager.go
+++ b/shared/relay/client/manager.go
@@ -78,9 +78,10 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
tokenStore: tokenStore,
mtu: mtu,
serverPicker: &ServerPicker{
- TokenStore: tokenStore,
- PeerID: peerID,
- MTU: mtu,
+ TokenStore: tokenStore,
+ PeerID: peerID,
+ MTU: mtu,
+ ConnectionTimeout: defaultConnectionTimeout,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
diff --git a/shared/relay/client/picker.go b/shared/relay/client/picker.go
index b6c7b5e8a..39d0ba072 100644
--- a/shared/relay/client/picker.go
+++ b/shared/relay/client/picker.go
@@ -13,11 +13,8 @@ import (
)
const (
- maxConcurrentServers = 7
-)
-
-var (
- connectionTimeout = 30 * time.Second
+ maxConcurrentServers = 7
+ defaultConnectionTimeout = 30 * time.Second
)
type connResult struct {
@@ -27,14 +24,15 @@ type connResult struct {
}
type ServerPicker struct {
- TokenStore *auth.TokenStore
- ServerURLs atomic.Value
- PeerID string
- MTU uint16
+ TokenStore *auth.TokenStore
+ ServerURLs atomic.Value
+ PeerID string
+ MTU uint16
+ ConnectionTimeout time.Duration
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
- ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
+ ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout)
defer cancel()
totalServers := len(sp.ServerURLs.Load().([]string))
diff --git a/shared/relay/client/picker_test.go b/shared/relay/client/picker_test.go
index 28167c5ce..fb3fa7375 100644
--- a/shared/relay/client/picker_test.go
+++ b/shared/relay/client/picker_test.go
@@ -8,15 +8,15 @@ import (
)
func TestServerPicker_UnavailableServers(t *testing.T) {
- connectionTimeout = 5 * time.Second
-
+ timeout := 5 * time.Second
sp := ServerPicker{
- TokenStore: nil,
- PeerID: "test",
+ TokenStore: nil,
+ PeerID: "test",
+ ConnectionTimeout: timeout,
}
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
- ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout+1)
defer cancel()
go func() {
diff --git a/shared/relay/healthcheck/env.go b/shared/relay/healthcheck/env.go
new file mode 100644
index 000000000..2b584c195
--- /dev/null
+++ b/shared/relay/healthcheck/env.go
@@ -0,0 +1,24 @@
+package healthcheck
+
+import (
+ "os"
+ "strconv"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
+)
+
+func getAttemptThresholdFromEnv() int {
+ if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
+ threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
+ if err != nil {
+ log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
+ return defaultAttemptThreshold
+ }
+ return int(threshold)
+ }
+ return defaultAttemptThreshold
+}
diff --git a/shared/relay/healthcheck/env_test.go b/shared/relay/healthcheck/env_test.go
new file mode 100644
index 000000000..2e14bb8bf
--- /dev/null
+++ b/shared/relay/healthcheck/env_test.go
@@ -0,0 +1,36 @@
+package healthcheck
+
+import (
+ "os"
+ "testing"
+)
+
+//nolint:tenv
+func TestGetAttemptThresholdFromEnv(t *testing.T) {
+ tests := []struct {
+ name string
+ envValue string
+ expected int
+ }{
+ {"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
+ {"Custom attempt threshold when env is set to a valid integer", "3", 3},
+ {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.envValue == "" {
+ os.Unsetenv(defaultAttemptThresholdEnv)
+ } else {
+ os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
+ }
+
+ result := getAttemptThresholdFromEnv()
+ if result != tt.expected {
+ t.Fatalf("Expected %d, got %d", tt.expected, result)
+ }
+
+ os.Unsetenv(defaultAttemptThresholdEnv)
+ })
+ }
+}
diff --git a/shared/relay/healthcheck/receiver.go b/shared/relay/healthcheck/receiver.go
index b3503d5db..90f795bbe 100644
--- a/shared/relay/healthcheck/receiver.go
+++ b/shared/relay/healthcheck/receiver.go
@@ -7,10 +7,15 @@ import (
log "github.com/sirupsen/logrus"
)
-var (
- heartbeatTimeout = healthCheckInterval + 10*time.Second
+const (
+ defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second
)
+type ReceiverOptions struct {
+ HeartbeatTimeout time.Duration
+ AttemptThreshold int
+}
+
// Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
@@ -27,6 +32,23 @@ type Receiver struct {
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver(log *log.Entry) *Receiver {
+ opts := ReceiverOptions{
+ HeartbeatTimeout: defaultHeartbeatTimeout,
+ AttemptThreshold: getAttemptThresholdFromEnv(),
+ }
+ return NewReceiverWithOpts(log, opts)
+}
+
+func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
+ heartbeatTimeout := opts.HeartbeatTimeout
+ if heartbeatTimeout <= 0 {
+ heartbeatTimeout = defaultHeartbeatTimeout
+ }
+ attemptThreshold := opts.AttemptThreshold
+ if attemptThreshold <= 0 {
+ attemptThreshold = defaultAttemptThreshold
+ }
+
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
@@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver {
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
- attemptThreshold: getAttemptThresholdFromEnv(),
+ attemptThreshold: attemptThreshold,
}
- go r.waitForHealthcheck()
+ go r.waitForHealthcheck(heartbeatTimeout)
return r
}
@@ -55,7 +77,7 @@ func (r *Receiver) Stop() {
r.ctxCancel()
}
-func (r *Receiver) waitForHealthcheck() {
+func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) {
ticker := time.NewTicker(heartbeatTimeout)
defer ticker.Stop()
defer r.ctxCancel()
diff --git a/shared/relay/healthcheck/receiver_test.go b/shared/relay/healthcheck/receiver_test.go
index 2794159f6..b20cc5124 100644
--- a/shared/relay/healthcheck/receiver_test.go
+++ b/shared/relay/healthcheck/receiver_test.go
@@ -2,31 +2,18 @@ package healthcheck
import (
"context"
- "fmt"
- "os"
- "sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
-// Mutex to protect global variable access in tests
-var testMutex sync.Mutex
-
func TestNewReceiver(t *testing.T) {
- testMutex.Lock()
- originalTimeout := heartbeatTimeout
- heartbeatTimeout = 5 * time.Second
- testMutex.Unlock()
- defer func() {
- testMutex.Lock()
- heartbeatTimeout = originalTimeout
- testMutex.Unlock()
- }()
-
- r := NewReceiver(log.WithContext(context.Background()))
+ opts := ReceiverOptions{
+ HeartbeatTimeout: 5 * time.Second,
+ }
+ r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
defer r.Stop()
select {
@@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) {
}
func TestNewReceiverNotReceive(t *testing.T) {
- testMutex.Lock()
- originalTimeout := heartbeatTimeout
- heartbeatTimeout = 1 * time.Second
- testMutex.Unlock()
-
- defer func() {
- testMutex.Lock()
- heartbeatTimeout = originalTimeout
- testMutex.Unlock()
- }()
-
- r := NewReceiver(log.WithContext(context.Background()))
+ opts := ReceiverOptions{
+ HeartbeatTimeout: 1 * time.Second,
+ }
+ r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
defer r.Stop()
select {
@@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) {
}
func TestNewReceiverAck(t *testing.T) {
- testMutex.Lock()
- originalTimeout := heartbeatTimeout
- heartbeatTimeout = 2 * time.Second
- testMutex.Unlock()
-
- defer func() {
- testMutex.Lock()
- heartbeatTimeout = originalTimeout
- testMutex.Unlock()
- }()
-
- r := NewReceiver(log.WithContext(context.Background()))
+ opts := ReceiverOptions{
+ HeartbeatTimeout: 2 * time.Second,
+ }
+ r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
defer r.Stop()
r.Heartbeat()
@@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
- testMutex.Lock()
- originalInterval := healthCheckInterval
- originalTimeout := heartbeatTimeout
- healthCheckInterval = 1 * time.Second
- heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
- testMutex.Unlock()
+ healthCheckInterval := 1 * time.Second
- defer func() {
- testMutex.Lock()
- healthCheckInterval = originalInterval
- heartbeatTimeout = originalTimeout
- testMutex.Unlock()
- }()
- //nolint:tenv
- os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
- defer os.Unsetenv(defaultAttemptThresholdEnv)
+ opts := ReceiverOptions{
+ HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond,
+ AttemptThreshold: tc.threshold,
+ }
- receiver := NewReceiver(log.WithField("test_name", tc.name))
+ receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts)
- testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
+ testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
- t.Logf("reset counter once")
}
select {
@@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
}
t.Fatalf("should have timed out before %s", testTimeout)
}
-
})
}
}
diff --git a/shared/relay/healthcheck/sender.go b/shared/relay/healthcheck/sender.go
index 57b3015ec..771e94206 100644
--- a/shared/relay/healthcheck/sender.go
+++ b/shared/relay/healthcheck/sender.go
@@ -2,52 +2,76 @@ package healthcheck
import (
"context"
- "os"
- "strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
- defaultAttemptThreshold = 1
- defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
+ defaultAttemptThreshold = 1
+
+ defaultHealthCheckInterval = 25 * time.Second
+ defaultHealthCheckTimeout = 20 * time.Second
)
-var (
- healthCheckInterval = 25 * time.Second
- healthCheckTimeout = 20 * time.Second
-)
+type SenderOptions struct {
+ HealthCheckInterval time.Duration
+ HealthCheckTimeout time.Duration
+ AttemptThreshold int
+}
// Sender is a healthcheck sender
// It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
- log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
+ log *log.Entry
+ healthCheckInterval time.Duration
+ timeout time.Duration
+
ack chan struct{}
alive bool
attemptThreshold int
}
-// NewSender creates a new healthcheck sender
-func NewSender(log *log.Entry) *Sender {
+func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender {
+ if opts.HealthCheckInterval <= 0 {
+ opts.HealthCheckInterval = defaultHealthCheckInterval
+ }
+ if opts.HealthCheckTimeout <= 0 {
+ opts.HealthCheckTimeout = defaultHealthCheckTimeout
+ }
+ if opts.AttemptThreshold <= 0 {
+ opts.AttemptThreshold = defaultAttemptThreshold
+ }
hc := &Sender{
- log: log,
- HealthCheck: make(chan struct{}, 1),
- Timeout: make(chan struct{}, 1),
- ack: make(chan struct{}, 1),
- attemptThreshold: getAttemptThresholdFromEnv(),
+ HealthCheck: make(chan struct{}, 1),
+ Timeout: make(chan struct{}, 1),
+ log: log,
+ healthCheckInterval: opts.HealthCheckInterval,
+ timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout,
+ ack: make(chan struct{}, 1),
+ attemptThreshold: opts.AttemptThreshold,
}
return hc
}
+// NewSender creates a new healthcheck sender
+func NewSender(log *log.Entry) *Sender {
+ opts := SenderOptions{
+ HealthCheckInterval: defaultHealthCheckInterval,
+ HealthCheckTimeout: defaultHealthCheckTimeout,
+ AttemptThreshold: getAttemptThresholdFromEnv(),
+ }
+ return NewSenderWithOpts(log, opts)
+}
+
// OnHCResponse sends an acknowledgment signal to the sender
func (hc *Sender) OnHCResponse() {
select {
@@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() {
}
func (hc *Sender) StartHealthCheck(ctx context.Context) {
- ticker := time.NewTicker(healthCheckInterval)
+ ticker := time.NewTicker(hc.healthCheckInterval)
defer ticker.Stop()
- timeoutTicker := time.NewTicker(hc.getTimeoutTime())
+ timeoutTicker := time.NewTicker(hc.timeout)
defer timeoutTicker.Stop()
defer close(hc.HealthCheck)
@@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
}
}
}
-
-func (hc *Sender) getTimeoutTime() time.Duration {
- return healthCheckInterval + healthCheckTimeout
-}
-
-func getAttemptThresholdFromEnv() int {
- if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
- threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
- if err != nil {
- log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
- return defaultAttemptThreshold
- }
- return int(threshold)
- }
- return defaultAttemptThreshold
-}
diff --git a/shared/relay/healthcheck/sender_test.go b/shared/relay/healthcheck/sender_test.go
index 23446366a..122fe0f16 100644
--- a/shared/relay/healthcheck/sender_test.go
+++ b/shared/relay/healthcheck/sender_test.go
@@ -2,26 +2,23 @@ package healthcheck
import (
"context"
- "fmt"
- "os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
-func TestMain(m *testing.M) {
- // override the health check interval to speed up the test
- healthCheckInterval = 2 * time.Second
- healthCheckTimeout = 100 * time.Millisecond
- code := m.Run()
- os.Exit(code)
-}
+var (
+ testOpts = SenderOptions{
+ HealthCheckInterval: 2 * time.Second,
+ HealthCheckTimeout: 100 * time.Millisecond,
+ }
+)
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- hc := NewSender(log.WithContext(ctx))
+ hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) {
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
- case <-time.After(healthCheckInterval + 100*time.Millisecond):
+ case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
@@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- hc := NewSender(log.WithContext(ctx))
+ hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx)
select {
case <-hc.Timeout:
- case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
+ case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond):
t.Fatalf("health check is not timed out")
}
}
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
- hc := NewSender(log.WithContext(ctx))
+ hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
@@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- hc := NewSender(log.WithContext(ctx))
+ hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) {
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
- case <-time.After(healthCheckInterval + 100*time.Millisecond):
+ case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
@@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
- originalInterval := healthCheckInterval
- originalTimeout := healthCheckTimeout
- healthCheckInterval = 1 * time.Second
- healthCheckTimeout = 500 * time.Millisecond
-
- //nolint:tenv
- os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
- defer os.Unsetenv(defaultAttemptThresholdEnv)
+ opts := SenderOptions{
+ HealthCheckInterval: 1 * time.Second,
+ HealthCheckTimeout: 500 * time.Millisecond,
+ AttemptThreshold: tc.threshold,
+ }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- sender := NewSender(log.WithField("test_name", tc.name))
+ sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts)
senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
@@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
}
}()
- testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
+ testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval
select {
case <-sender.Timeout:
@@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
- healthCheckInterval = originalInterval
- healthCheckTimeout = originalTimeout
})
}
}
-
-//nolint:tenv
-func TestGetAttemptThresholdFromEnv(t *testing.T) {
- tests := []struct {
- name string
- envValue string
- expected int
- }{
- {"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
- {"Custom attempt threshold when env is set to a valid integer", "3", 3},
- {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if tt.envValue == "" {
- os.Unsetenv(defaultAttemptThresholdEnv)
- } else {
- os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
- }
-
- result := getAttemptThresholdFromEnv()
- if result != tt.expected {
- t.Fatalf("Expected %d, got %d", tt.expected, result)
- }
-
- os.Unsetenv(defaultAttemptThresholdEnv)
- })
- }
-}
diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go
index 82ab678f4..5ca0c0282 100644
--- a/shared/signal/client/grpc.go
+++ b/shared/signal/client/grpc.go
@@ -16,10 +16,10 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
+ nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/signal/proto"
- nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
// ConnStateNotifier is a wrapper interface of the status recorder
@@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
operation := func() error {
var err error
- conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
+ conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go
index db428515b..bc2d4d1be 100644
--- a/sharedsock/sock_linux.go
+++ b/sharedsock/sock_linux.go
@@ -22,7 +22,7 @@ import (
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"
- nbnet "github.com/netbirdio/netbird/util/net"
+ nbnet "github.com/netbirdio/netbird/client/net"
)
// ErrSharedSockStopped indicates that shared socket has been stopped
@@ -93,7 +93,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
}
if err = nbnet.SetSocketMark(rawSock.conn4); err != nil {
- return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err)
+ return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err)
}
var sockErr error
@@ -102,7 +102,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
log.Errorf("Failed to create ipv6 raw socket: %v", err)
} else {
if err = nbnet.SetSocketMark(rawSock.conn6); err != nil {
- return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err)
+ return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err)
}
}
@@ -230,10 +230,8 @@ func (s *SharedSocket) Close() error {
// read start a read loop for a specific receiver and sends the packet to the packetDemux channel
func (s *SharedSocket) read(receiver receiver) {
- // Buffer reuse is safe: packetDemux is unbuffered, so read() blocks until
- // ReadFrom() synchronously processes the packet before next iteration
- buf := make([]byte, s.mtu+maxIPUDPOverhead)
for {
+ buf := make([]byte, s.mtu+maxIPUDPOverhead)
n, addr, err := receiver(s.ctx, buf, 0)
select {
case <-s.ctx.Done():
diff --git a/signal/cmd/run.go b/signal/cmd/run.go
index 2e89b491a..1d76fa4e4 100644
--- a/signal/cmd/run.go
+++ b/signal/cmd/run.go
@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/http"
+
// nolint:gosec
_ "net/http/pprof"
"strings"
diff --git a/signal/peer/peer.go b/signal/peer/peer.go
index f21c95a41..c9dd60fc0 100644
--- a/signal/peer/peer.go
+++ b/signal/peer/peer.go
@@ -5,10 +5,16 @@ import (
"sync"
"time"
+ "errors"
+
log "github.com/sirupsen/logrus"
- "github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/shared/signal/proto"
+ "github.com/netbirdio/netbird/signal/metrics"
+)
+
+var (
+ ErrPeerAlreadyRegistered = errors.New("peer already registered")
)
// Peer representation of a connected Peer
@@ -23,15 +29,18 @@ type Peer struct {
// registration time
RegisteredAt time.Time
+
+ Cancel context.CancelFunc
}
// NewPeer creates a new instance of a connected Peer
-func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
+func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
return &Peer{
Id: id,
Stream: stream,
StreamID: time.Now().UnixNano(),
RegisteredAt: time.Now(),
+ Cancel: cancel,
}
}
@@ -69,20 +78,24 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool {
}
// Register registers peer in the registry
-func (registry *Registry) Register(peer *Peer) {
+func (registry *Registry) Register(peer *Peer) error {
start := time.Now()
- registry.regMutex.Lock()
- defer registry.regMutex.Unlock()
-
// can be that peer already exists, but it is fine (e.g. reconnect)
p, loaded := registry.Peers.LoadOrStore(peer.Id, peer)
if loaded {
pp := p.(*Peer)
- log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
- peer.Id, peer.StreamID, pp.StreamID)
- registry.Peers.Store(peer.Id, peer)
- return
+ if peer.StreamID > pp.StreamID {
+ log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
+ peer.Id, peer.StreamID, pp.StreamID)
+ if swapped := registry.Peers.CompareAndSwap(peer.Id, pp, peer); !swapped {
+ return registry.Register(peer)
+ }
+ pp.Cancel()
+ log.Debugf("peer re-registered [%s]", peer.Id)
+ return nil
+ }
+ return ErrPeerAlreadyRegistered
}
log.Debugf("peer registered [%s]", peer.Id)
@@ -92,22 +105,13 @@ func (registry *Registry) Register(peer *Peer) {
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
registry.metrics.Registrations.Add(context.Background(), 1)
+
+ return nil
}
// Deregister Peer from the Registry (usually once it disconnects)
func (registry *Registry) Deregister(peer *Peer) {
- registry.regMutex.Lock()
- defer registry.regMutex.Unlock()
-
- p, loaded := registry.Peers.LoadAndDelete(peer.Id)
- if loaded {
- pp := p.(*Peer)
- if peer.StreamID < pp.StreamID {
- registry.Peers.Store(peer.Id, p)
- log.Debugf("attempted to remove newer registered stream of a peer [%s] [newer streamID %d, previous StreamID %d]. Ignoring.",
- peer.Id, pp.StreamID, peer.StreamID)
- return
- }
+ if deleted := registry.Peers.CompareAndDelete(peer.Id, peer); deleted {
registry.metrics.ActivePeers.Add(context.Background(), -1)
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
diff --git a/signal/peer/peer_test.go b/signal/peer/peer_test.go
index fb85fedda..6b7976eb4 100644
--- a/signal/peer/peer_test.go
+++ b/signal/peer/peer_test.go
@@ -1,13 +1,18 @@
package peer
import (
+ "context"
+ "sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+ "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/metrics"
)
@@ -19,12 +24,16 @@ func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T)
peerID := "peer"
- olderPeer := NewPeer(peerID, nil)
- r.Register(olderPeer)
+ _, cancel1 := context.WithCancel(context.Background())
+ olderPeer := NewPeer(peerID, nil, cancel1)
+ err = r.Register(olderPeer)
+ require.NoError(t, err)
time.Sleep(time.Nanosecond)
- newerPeer := NewPeer(peerID, nil)
- r.Register(newerPeer)
+ _, cancel2 := context.WithCancel(context.Background())
+ newerPeer := NewPeer(peerID, nil, cancel2)
+ err = r.Register(newerPeer)
+ require.NoError(t, err)
registered, _ := r.Get(olderPeer.Id)
assert.NotNil(t, registered, "peer can't be nil")
@@ -59,10 +68,14 @@ func TestRegistry_Register(t *testing.T) {
require.NoError(t, err)
r := NewRegistry(metrics)
- peer1 := NewPeer("test_peer_1", nil)
- peer2 := NewPeer("test_peer_2", nil)
- r.Register(peer1)
- r.Register(peer2)
+ _, cancel1 := context.WithCancel(context.Background())
+ peer1 := NewPeer("test_peer_1", nil, cancel1)
+ _, cancel2 := context.WithCancel(context.Background())
+ peer2 := NewPeer("test_peer_2", nil, cancel2)
+ err = r.Register(peer1)
+ require.NoError(t, err)
+ err = r.Register(peer2)
+ require.NoError(t, err)
if _, ok := r.Get("test_peer_1"); !ok {
t.Errorf("expected test_peer_1 not found in the registry")
@@ -78,10 +91,14 @@ func TestRegistry_Deregister(t *testing.T) {
require.NoError(t, err)
r := NewRegistry(metrics)
- peer1 := NewPeer("test_peer_1", nil)
- peer2 := NewPeer("test_peer_2", nil)
- r.Register(peer1)
- r.Register(peer2)
+ _, cancel1 := context.WithCancel(context.Background())
+ peer1 := NewPeer("test_peer_1", nil, cancel1)
+ _, cancel2 := context.WithCancel(context.Background())
+ peer2 := NewPeer("test_peer_2", nil, cancel2)
+ err = r.Register(peer1)
+ require.NoError(t, err)
+ err = r.Register(peer2)
+ require.NoError(t, err)
r.Deregister(peer1)
@@ -94,3 +111,213 @@ func TestRegistry_Deregister(t *testing.T) {
}
}
+
+func TestRegistry_MultipleRegister_Concurrency(t *testing.T) {
+
+ metrics, err := metrics.NewAppMetrics(otel.Meter(""))
+ require.NoError(t, err)
+ registry := NewRegistry(metrics)
+
+ numGoroutines := 1000
+
+ ids := make(chan int64, numGoroutines)
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+ peerID := "peer-concurrent"
+ for i := range numGoroutines {
+ go func(routineIndex int) {
+ defer wg.Done()
+
+ _, cancel := context.WithCancel(context.Background())
+ peer := NewPeer(peerID, nil, cancel)
+ _ = registry.Register(peer)
+ ids <- peer.StreamID
+ }(i)
+ }
+
+ wg.Wait()
+ close(ids)
+ maxId := int64(0)
+ for id := range ids {
+ maxId = max(maxId, id)
+ }
+
+ peer, ok := registry.Get(peerID)
+ require.True(t, ok, "expected peer to be registered")
+ require.Equal(t, maxId, peer.StreamID, "expected the highest StreamID to be registered")
+}
+
+func Benchmark_MultipleRegister_Concurrency(b *testing.B) {
+
+ metrics, err := metrics.NewAppMetrics(otel.Meter(""))
+ require.NoError(b, err)
+
+ numGoroutines := 1000
+
+ var wg sync.WaitGroup
+ peerID := "peer-concurrent"
+ _, cancel := context.WithCancel(context.Background())
+ b.Run("multiple-register", func(b *testing.B) {
+ registry := NewRegistry(metrics)
+ b.ResetTimer()
+ for j := 0; j < b.N; j++ {
+ wg.Add(numGoroutines)
+ for i := range numGoroutines {
+ go func(routineIndex int) {
+ defer wg.Done()
+
+ peer := NewPeer(peerID, nil, cancel)
+ _ = registry.Register(peer)
+ }(i)
+ }
+ wg.Wait()
+ }
+ })
+}
+
+func TestRegistry_MultipleDeregister_Concurrency(t *testing.T) {
+
+ metrics, err := metrics.NewAppMetrics(otel.Meter(""))
+ require.NoError(t, err)
+ registry := NewRegistry(metrics)
+
+ numGoroutines := 1000
+
+ ids := make(chan int64, numGoroutines)
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+ peerID := "peer-concurrent"
+ for i := range numGoroutines {
+ go func(routineIndex int) {
+ defer wg.Done()
+
+ _, cancel := context.WithCancel(context.Background())
+ peer := NewPeer(peerID, nil, cancel)
+ _ = registry.Register(peer)
+ ids <- peer.StreamID
+ registry.Deregister(peer)
+ }(i)
+ }
+
+ wg.Wait()
+ close(ids)
+ maxId := int64(0)
+ for id := range ids {
+ maxId = max(maxId, id)
+ }
+
+ _, ok := registry.Get(peerID)
+ require.False(t, ok, "expected peer to be deregistered")
+}
+
+func Benchmark_MultipleDeregister_Concurrency(b *testing.B) {
+
+ metrics, err := metrics.NewAppMetrics(otel.Meter(""))
+ require.NoError(b, err)
+
+ numGoroutines := 1000
+
+ var wg sync.WaitGroup
+ peerID := "peer-concurrent"
+ _, cancel := context.WithCancel(context.Background())
+ b.Run("register-deregister", func(b *testing.B) {
+ registry := NewRegistry(metrics)
+ b.ResetTimer()
+ for j := 0; j < b.N; j++ {
+ wg.Add(numGoroutines)
+ for i := range numGoroutines {
+ go func(routineIndex int) {
+ defer wg.Done()
+
+ peer := NewPeer(peerID, nil, cancel)
+ _ = registry.Register(peer)
+ time.Sleep(time.Nanosecond)
+ registry.Deregister(peer)
+ }(i)
+ }
+ wg.Wait()
+ }
+ })
+}
+
+type mockConnectStreamServer struct {
+ grpc.ServerStream
+ ctx context.Context
+}
+
+func (m *mockConnectStreamServer) Context() context.Context {
+ return m.ctx
+}
+
+func (m *mockConnectStreamServer) SendHeader(md metadata.MD) error {
+ return nil
+}
+
+func (m *mockConnectStreamServer) Send(msg *proto.EncryptedMessage) error {
+ return nil
+}
+
+func (m *mockConnectStreamServer) Recv() (*proto.EncryptedMessage, error) {
+ <-m.ctx.Done()
+ return nil, m.ctx.Err()
+}
+
+func TestReconnectHandling(t *testing.T) {
+ metrics, err := metrics.NewAppMetrics(otel.Meter(""))
+ require.NoError(t, err)
+ registry := NewRegistry(metrics)
+ peerID := "test-peer-reconnect"
+
+ ctx1, cancel1 := context.WithCancel(context.Background())
+ defer cancel1()
+ stream1 := &mockConnectStreamServer{ctx: ctx1}
+ peer1 := NewPeer(peerID, stream1, cancel1)
+
+ err = registry.Register(peer1)
+ require.NoError(t, err, "first registration should succeed")
+
+ p, found := registry.Get(peerID)
+ require.True(t, found, "peer should be found in the registry")
+ require.Equal(t, peer1.StreamID, p.StreamID, "StreamID of registered peer should match")
+
+ time.Sleep(time.Nanosecond)
+ ctx2, cancel2 := context.WithCancel(context.Background())
+ defer cancel2()
+ stream2 := &mockConnectStreamServer{ctx: ctx2}
+ peer2 := NewPeer(peerID, stream2, cancel2)
+
+ err = registry.Register(peer2)
+ require.NoError(t, err, "reconnect registration should succeed")
+
+ select {
+ case <-ctx1.Done():
+ require.ErrorIs(t, ctx1.Err(), context.Canceled, "context of old stream should be canceled after successful reconnection")
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("context of old stream was not canceled after reconnection")
+ }
+
+ p, found = registry.Get(peerID)
+ require.True(t, found)
+ require.Equal(t, peer2.StreamID, p.StreamID, "registered peer should have the new StreamID after reconnection")
+
+ ctx3, cancel3 := context.WithCancel(context.Background())
+ defer cancel3()
+ stream3 := &mockConnectStreamServer{ctx: ctx3}
+ stalePeer := NewPeer(peerID, stream3, cancel3)
+ stalePeer.StreamID = peer1.StreamID
+
+ err = registry.Register(stalePeer)
+ require.ErrorIs(t, err, ErrPeerAlreadyRegistered, "reconnecting with an old StreamID should return an error")
+
+ p, found = registry.Get(peerID)
+ require.True(t, found)
+ require.Equal(t, peer2.StreamID, p.StreamID, "active peer should still be the one with the latest StreamID")
+
+ select {
+ case <-ctx2.Done():
+ t.Fatal("context of the new stream should not be canceled after trying to register with an old StreamID")
+ default:
+ }
+}
diff --git a/signal/server/signal.go b/signal/server/signal.go
index 8ae14822b..47f01edae 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -2,7 +2,9 @@ package server
import (
"context"
+ "errors"
"fmt"
+ "os"
"time"
log "github.com/sirupsen/logrus"
@@ -15,9 +17,9 @@ import (
"github.com/netbirdio/signal-dispatcher/dispatcher"
+ "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
- "github.com/netbirdio/netbird/shared/signal/proto"
)
const (
@@ -27,6 +29,8 @@ const (
labelTypeNotRegistered = "not_registered"
labelTypeStream = "stream"
labelTypeMessage = "message"
+ labelTypeTimeout = "timeout"
+ labelTypeDisconnected = "disconnected"
labelError = "error"
labelErrorMissingId = "missing_id"
@@ -37,6 +41,12 @@ const (
labelRegistrationStatus = "status"
labelRegistrationFound = "found"
labelRegistrationNotFound = "not_found"
+
+ sendTimeout = 10 * time.Second
+)
+
+var (
+ ErrPeerRegisteredAgain = errors.New("peer registered again")
)
// Server an instance of a Signal server
@@ -45,6 +55,10 @@ type Server struct {
proto.UnimplementedSignalExchangeServer
dispatcher *dispatcher.Dispatcher
metrics *metrics.AppMetrics
+
+ successHeader metadata.MD
+
+ sendTimeout time.Duration
}
// NewServer creates a new Signal server
@@ -59,10 +73,19 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
return nil, fmt.Errorf("creating dispatcher: %v", err)
}
+ sTimeout := sendTimeout
+ to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT")
+ if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 {
+ log.Trace("using custom send timeout ", parsed)
+ sTimeout = parsed
+ }
+
s := &Server{
- dispatcher: d,
- registry: peer.NewRegistry(appMetrics),
- metrics: appMetrics,
+ dispatcher: d,
+ registry: peer.NewRegistry(appMetrics),
+ metrics: appMetrics,
+ successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
+ sendTimeout: sTimeout,
}
return s, nil
@@ -82,7 +105,8 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
// ConnectStream connects to the exchange stream
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
- p, err := s.RegisterPeer(stream)
+ ctx, cancel := context.WithCancel(context.Background())
+ p, err := s.RegisterPeer(stream, cancel)
if err != nil {
return err
}
@@ -90,8 +114,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
defer s.DeregisterPeer(p)
// needed to confirm that the peer has been registered so that the client can proceed
- header := metadata.Pairs(proto.HeaderRegistered, "1")
- err = stream.SendHeader(header)
+ err = stream.SendHeader(s.successHeader)
if err != nil {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader)))
return err
@@ -99,27 +122,27 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
- <-stream.Context().Done()
- log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
- return nil
+ select {
+ case <-stream.Context().Done():
+ log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
+ return nil
+ case <-ctx.Done():
+ return ErrPeerRegisteredAgain
+ }
}
-func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
+func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) {
log.Debugf("registering new peer")
- meta, hasMeta := metadata.FromIncomingContext(stream.Context())
- if !hasMeta {
- s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta)))
- return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
- }
-
- id, found := meta[proto.HeaderId]
- if !found {
+ id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId)
+ if id == nil {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId)
}
- p := peer.NewPeer(id[0], stream)
- s.registry.Register(p)
+ p := peer.NewPeer(id[0], stream, cancel)
+ if err := s.registry.Register(p); err != nil {
+ return nil, err
+ }
err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
if err != nil {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration)))
@@ -131,8 +154,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
func (s *Server) DeregisterPeer(p *peer.Peer) {
log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
- s.registry.Deregister(p)
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
+ s.registry.Deregister(p)
}
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
@@ -145,7 +168,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
if !found {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
- log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
+ log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
// todo respond to the sender?
return
}
@@ -153,16 +176,34 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()
- // forward the message to the target peer
- if err := dstPeer.Stream.Send(msg); err != nil {
- log.Tracef("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
- // todo respond to the sender?
- s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
- return
- }
+ sendResultChan := make(chan error, 1)
+ go func() {
+ select {
+ case sendResultChan <- dstPeer.Stream.Send(msg):
+ return
+ case <-dstPeer.Stream.Context().Done():
+ return
+ }
+ }()
- // in milliseconds
- s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
- s.metrics.MessagesForwarded.Add(ctx, 1)
- s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
+ select {
+ case err := <-sendResultChan:
+ if err != nil {
+ log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err)
+ s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
+ return
+ }
+ s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
+ s.metrics.MessagesForwarded.Add(ctx, 1)
+ s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
+
+ case <-dstPeer.Stream.Context().Done():
+ log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey)
+ s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected)))
+
+ case <-time.After(s.sendTimeout):
+ dstPeer.Cancel() // cancel the peer context to trigger deregistration
+ log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey)
+ s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout)))
+ }
}
diff --git a/util/net/conn.go b/util/net/conn.go
deleted file mode 100644
index 26693f841..000000000
--- a/util/net/conn.go
+++ /dev/null
@@ -1,31 +0,0 @@
-//go:build !ios
-
-package net
-
-import (
- "net"
-
- log "github.com/sirupsen/logrus"
-)
-
-// Conn wraps a net.Conn to override the Close method
-type Conn struct {
- net.Conn
- ID ConnectionID
-}
-
-// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
-func (c *Conn) Close() error {
- err := c.Conn.Close()
-
- dialerCloseHooksMutex.RLock()
- defer dialerCloseHooksMutex.RUnlock()
-
- for _, hook := range dialerCloseHooks {
- if err := hook(c.ID, &c.Conn); err != nil {
- log.Errorf("Error executing dialer close hook: %v", err)
- }
- }
-
- return err
-}
diff --git a/util/net/dial.go b/util/net/dial.go
deleted file mode 100644
index 595311492..000000000
--- a/util/net/dial.go
+++ /dev/null
@@ -1,58 +0,0 @@
-//go:build !ios
-
-package net
-
-import (
- "fmt"
- "net"
-
- log "github.com/sirupsen/logrus"
-)
-
-func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
- if CustomRoutingDisabled() {
- return net.DialUDP(network, laddr, raddr)
- }
-
- dialer := NewDialer()
- dialer.LocalAddr = laddr
-
- conn, err := dialer.Dial(network, raddr.String())
- if err != nil {
- return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
- }
-
- udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
- if !ok {
- if err := conn.Close(); err != nil {
- log.Errorf("Failed to close connection: %v", err)
- }
- return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
- }
-
- return udpConn, nil
-}
-
-func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
- if CustomRoutingDisabled() {
- return net.DialTCP(network, laddr, raddr)
- }
-
- dialer := NewDialer()
- dialer.LocalAddr = laddr
-
- conn, err := dialer.Dial(network, raddr.String())
- if err != nil {
- return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
- }
-
- tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
- if !ok {
- if err := conn.Close(); err != nil {
- log.Errorf("Failed to close connection: %v", err)
- }
- return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
- }
-
- return tcpConn, nil
-}
diff --git a/util/net/dialer_dial.go b/util/net/dialer_dial.go
deleted file mode 100644
index 1659b6220..000000000
--- a/util/net/dialer_dial.go
+++ /dev/null
@@ -1,107 +0,0 @@
-//go:build !ios
-
-package net
-
-import (
- "context"
- "fmt"
- "net"
- "sync"
-
- "github.com/hashicorp/go-multierror"
- log "github.com/sirupsen/logrus"
-)
-
-type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
-type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
-
-var (
- dialerDialHooksMutex sync.RWMutex
- dialerDialHooks []DialerDialHookFunc
- dialerCloseHooksMutex sync.RWMutex
- dialerCloseHooks []DialerCloseHookFunc
-)
-
-// AddDialerHook allows adding a new hook to be executed before dialing.
-func AddDialerHook(hook DialerDialHookFunc) {
- dialerDialHooksMutex.Lock()
- defer dialerDialHooksMutex.Unlock()
- dialerDialHooks = append(dialerDialHooks, hook)
-}
-
-// AddDialerCloseHook allows adding a new hook to be executed on connection close.
-func AddDialerCloseHook(hook DialerCloseHookFunc) {
- dialerCloseHooksMutex.Lock()
- defer dialerCloseHooksMutex.Unlock()
- dialerCloseHooks = append(dialerCloseHooks, hook)
-}
-
-// RemoveDialerHooks removes all dialer hooks.
-func RemoveDialerHooks() {
- dialerDialHooksMutex.Lock()
- defer dialerDialHooksMutex.Unlock()
- dialerDialHooks = nil
-
- dialerCloseHooksMutex.Lock()
- defer dialerCloseHooksMutex.Unlock()
- dialerCloseHooks = nil
-}
-
-// DialContext wraps the net.Dialer's DialContext method to use the custom connection
-func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
- log.Debugf("Dialing %s %s", network, address)
-
- if CustomRoutingDisabled() {
- return d.Dialer.DialContext(ctx, network, address)
- }
-
- var resolver *net.Resolver
- if d.Resolver != nil {
- resolver = d.Resolver
- }
-
- connID := GenerateConnID()
- if dialerDialHooks != nil {
- if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
- log.Errorf("Failed to call dialer hooks: %v", err)
- }
- }
-
- conn, err := d.Dialer.DialContext(ctx, network, address)
- if err != nil {
- return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
- }
-
- // Wrap the connection in Conn to handle Close with hooks
- return &Conn{Conn: conn, ID: connID}, nil
-}
-
-// Dial wraps the net.Dialer's Dial method to use the custom connection
-func (d *Dialer) Dial(network, address string) (net.Conn, error) {
- return d.DialContext(context.Background(), network, address)
-}
-
-func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
- host, _, err := net.SplitHostPort(address)
- if err != nil {
- return fmt.Errorf("split host and port: %w", err)
- }
- ips, err := resolver.LookupIPAddr(ctx, host)
- if err != nil {
- return fmt.Errorf("failed to resolve address %s: %w", address, err)
- }
-
- log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
-
- var result *multierror.Error
-
- dialerDialHooksMutex.RLock()
- defer dialerDialHooksMutex.RUnlock()
- for _, hook := range dialerDialHooks {
- if err := hook(ctx, connID, ips); err != nil {
- result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
- }
- }
-
- return result.ErrorOrNil()
-}
diff --git a/util/net/dialer_init_nonlinux.go b/util/net/dialer_init_nonlinux.go
deleted file mode 100644
index 8c57ebbaa..000000000
--- a/util/net/dialer_init_nonlinux.go
+++ /dev/null
@@ -1,7 +0,0 @@
-//go:build !linux
-
-package net
-
-func (d *Dialer) init() {
- // implemented on Linux and Android only
-}
diff --git a/util/net/env_generic.go b/util/net/env_generic.go
deleted file mode 100644
index 6d142a838..000000000
--- a/util/net/env_generic.go
+++ /dev/null
@@ -1,12 +0,0 @@
-//go:build !linux || android
-
-package net
-
-func Init() {
- // nothing to do on non-linux
-}
-
-func AdvancedRouting() bool {
- // non-linux currently doesn't support advanced routing
- return false
-}
diff --git a/util/net/listen.go b/util/net/listen.go
deleted file mode 100644
index 3ae8a9435..000000000
--- a/util/net/listen.go
+++ /dev/null
@@ -1,37 +0,0 @@
-//go:build !ios
-
-package net
-
-import (
- "context"
- "fmt"
- "net"
- "sync"
-
- "github.com/pion/transport/v3"
- log "github.com/sirupsen/logrus"
-)
-
-// ListenUDP listens on the network address and returns a transport.UDPConn
-// which includes support for write and close hooks.
-func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
- if CustomRoutingDisabled() {
- return net.ListenUDP(network, laddr)
- }
-
- conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
- if err != nil {
- return nil, fmt.Errorf("listen UDP: %w", err)
- }
-
- packetConn := conn.(*PacketConn)
- udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
- if !ok {
- if err := packetConn.Close(); err != nil {
- log.Errorf("Failed to close connection: %v", err)
- }
- return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
- }
-
- return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
-}
diff --git a/util/net/listener_init_nonlinux.go b/util/net/listener_init_nonlinux.go
deleted file mode 100644
index 80f6f7f1a..000000000
--- a/util/net/listener_init_nonlinux.go
+++ /dev/null
@@ -1,7 +0,0 @@
-//go:build !linux
-
-package net
-
-func (l *ListenerConfig) init() {
- // implemented on Linux and Android only
-}
diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go
deleted file mode 100644
index 4060ab49a..000000000
--- a/util/net/listener_listen.go
+++ /dev/null
@@ -1,205 +0,0 @@
-//go:build !ios
-
-package net
-
-import (
- "context"
- "fmt"
- "net"
- "net/netip"
- "sync"
-
- log "github.com/sirupsen/logrus"
-)
-
-// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
-type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
-
-// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
-type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
-
-// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
-type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
-
-var (
- listenerWriteHooksMutex sync.RWMutex
- listenerWriteHooks []ListenerWriteHookFunc
- listenerCloseHooksMutex sync.RWMutex
- listenerCloseHooks []ListenerCloseHookFunc
- listenerAddressRemoveHooksMutex sync.RWMutex
- listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
-)
-
-// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
-func AddListenerWriteHook(hook ListenerWriteHookFunc) {
- listenerWriteHooksMutex.Lock()
- defer listenerWriteHooksMutex.Unlock()
- listenerWriteHooks = append(listenerWriteHooks, hook)
-}
-
-// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
-func AddListenerCloseHook(hook ListenerCloseHookFunc) {
- listenerCloseHooksMutex.Lock()
- defer listenerCloseHooksMutex.Unlock()
- listenerCloseHooks = append(listenerCloseHooks, hook)
-}
-
-// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
-func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
- listenerAddressRemoveHooksMutex.Lock()
- defer listenerAddressRemoveHooksMutex.Unlock()
- listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
-}
-
-// RemoveListenerHooks removes all listener hooks.
-func RemoveListenerHooks() {
- listenerWriteHooksMutex.Lock()
- defer listenerWriteHooksMutex.Unlock()
- listenerWriteHooks = nil
-
- listenerCloseHooksMutex.Lock()
- defer listenerCloseHooksMutex.Unlock()
- listenerCloseHooks = nil
-
- listenerAddressRemoveHooksMutex.Lock()
- defer listenerAddressRemoveHooksMutex.Unlock()
- listenerAddressRemoveHooks = nil
-}
-
-// ListenPacket listens on the network address and returns a PacketConn
-// which includes support for write hooks.
-func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
- if CustomRoutingDisabled() {
- return l.ListenConfig.ListenPacket(ctx, network, address)
- }
-
- pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
- if err != nil {
- return nil, fmt.Errorf("listen packet: %w", err)
- }
- connID := GenerateConnID()
-
- return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
-}
-
-// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
-type PacketConn struct {
- net.PacketConn
- ID ConnectionID
- seenAddrs *sync.Map
-}
-
-// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
-func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
- callWriteHooks(c.ID, c.seenAddrs, b, addr)
- return c.PacketConn.WriteTo(b, addr)
-}
-
-// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
-func (c *PacketConn) Close() error {
- c.seenAddrs = &sync.Map{}
- return closeConn(c.ID, c.PacketConn)
-}
-
-// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
-type UDPConn struct {
- *net.UDPConn
- ID ConnectionID
- seenAddrs *sync.Map
-}
-
-// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
-func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
- callWriteHooks(c.ID, c.seenAddrs, b, addr)
- return c.UDPConn.WriteTo(b, addr)
-}
-
-// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
-func (c *UDPConn) Close() error {
- c.seenAddrs = &sync.Map{}
- return closeConn(c.ID, c.UDPConn)
-}
-
-// RemoveAddress removes an address from the seen cache and triggers removal hooks.
-func (c *PacketConn) RemoveAddress(addr string) {
- if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
- return
- }
-
- ipStr, _, err := net.SplitHostPort(addr)
- if err != nil {
- log.Errorf("Error splitting IP address and port: %v", err)
- return
- }
-
- ipAddr, err := netip.ParseAddr(ipStr)
- if err != nil {
- log.Errorf("Error parsing IP address %s: %v", ipStr, err)
- return
- }
-
- prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
-
- listenerAddressRemoveHooksMutex.RLock()
- defer listenerAddressRemoveHooksMutex.RUnlock()
-
- for _, hook := range listenerAddressRemoveHooks {
- if err := hook(c.ID, prefix); err != nil {
- log.Errorf("Error executing listener address remove hook: %v", err)
- }
- }
-}
-
-
-// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality
-func WrapPacketConn(conn net.PacketConn) *PacketConn {
- return &PacketConn{
- PacketConn: conn,
- ID: GenerateConnID(),
- seenAddrs: &sync.Map{},
- }
-}
-
-func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
- // Lookup the address in the seenAddrs map to avoid calling the hooks for every write
- if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
- ipStr, _, splitErr := net.SplitHostPort(addr.String())
- if splitErr != nil {
- log.Errorf("Error splitting IP address and port: %v", splitErr)
- return
- }
-
- ip, err := net.ResolveIPAddr("ip", ipStr)
- if err != nil {
- log.Errorf("Error resolving IP address: %v", err)
- return
- }
- log.Debugf("Listener resolved IP for %s: %s", addr, ip)
-
- func() {
- listenerWriteHooksMutex.RLock()
- defer listenerWriteHooksMutex.RUnlock()
-
- for _, hook := range listenerWriteHooks {
- if err := hook(id, ip, b); err != nil {
- log.Errorf("Error executing listener write hook: %v", err)
- }
- }
- }()
- }
-}
-
-func closeConn(id ConnectionID, conn net.PacketConn) error {
- err := conn.Close()
-
- listenerCloseHooksMutex.RLock()
- defer listenerCloseHooksMutex.RUnlock()
-
- for _, hook := range listenerCloseHooks {
- if err := hook(id, conn); err != nil {
- log.Errorf("Error executing listener close hook: %v", err)
- }
- }
-
- return err
-}