Compare commits

..

11 Commits

Author SHA1 Message Date
Viktor Liu
58c79f5878 [client] Fix DNS custom zone teardown: handler leak and external CNAME resolution (#6445) 2026-06-19 17:33:09 +02:00
Viktor Liu
15a0504fb1 [client] Treat answering upstreams as reachable and widen DNS health grace window (#6453) 2026-06-19 17:32:49 +02:00
Riccardo Manfrin
883a1a8961 [client] Fix profile regressions in up --profile and status (#6479)
* Restores behavior to create profile if not there on Up

* Allows to restore nerbird status showing of the profile name

* [client] Reduce upFunc cognitive complexity

Extract the profile switch/auto-create logic from upFunc into a dedicated
switchOrCreateProfile helper. The inlined NotFound-retry branch pushed
upFunc over SonarCloud's cognitive complexity threshold (S3776).
No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* [client] Make up --profile auto-create idempotent under concurrent runs

Don't fail switchOrCreateProfile on a createProfile error: a concurrent
run may create the profile between the NotFound check and our create
call. Retry the switch regardless and only surface the create error if
the switch also fails. Addresses CodeRabbit race-condition feedback.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* Share createProfile with addProfileFunc

* But allow conn reusage

* moves switchOrCreateProfile to where it's used

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-19 16:23:51 +02:00
Maycon Santos
54192a94b7 [misc] handle release candidates when fetching tags in FreeBSD port scripts (#6480)
* [misc] Exclude release candidates when fetching tags in FreeBSD port scripts
2026-06-19 14:10:43 +02:00
Pascal Fischer
8511687270 [management] log peer meta diff (#6468) 2026-06-19 13:30:52 +02:00
Pascal Fischer
35b465fa4a [management] reduce sync and login transaction (#6472) 2026-06-19 11:43:01 +02:00
Brad Ison
fb87f751a5 [management] Fetch complete user data in ValidateTunnelPeer (#6457)
* [management] Fetch complete user data in ValidateTunnelPeer

Previously the `ValidateTunnelPeer` method used by the ProxyService
would fetch user information from the database if the connected peer
was associated with a user ID, but it would not consult the IdP data
for cached info from JWT claims like email.  This caused the value of
the injected `X-Netbird-User` header to always display the peer ID and
never the user email associated with the peer as expected.

This change adds an optional IdP manager to the ProxyService and
fetches the complete user data from it if present.

* [management] Refactor ValidateTunnelPeer principal info gathering

This refactors the gathering of info on proxy tunnel peer principals
into its own method to keep the complexity down and make Sonar happy.
2026-06-19 11:39:21 +02:00
Maycon Santos
679c7182a4 [misc] Remove version prefix v docker tags (#6471) 2026-06-18 22:34:24 +02:00
Pascal Fischer
8c031ea6f0 [management] remove db calls in nested loops (#6470) 2026-06-18 22:12:59 +02:00
Pascal Fischer
60a9544656 [management] pass meta update for browser clients (#6465) 2026-06-18 17:22:42 +02:00
Viktor Liu
d3710d4bb2 [signal] Serialize concurrent sends to a peer signal stream (#6463) 2026-06-18 15:00:19 +02:00
37 changed files with 1142 additions and 1197 deletions

View File

@@ -247,7 +247,7 @@ dockers_v2:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
@@ -295,7 +295,7 @@ dockers_v2:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
@@ -317,7 +317,7 @@ dockers_v2:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
@@ -339,7 +339,7 @@ dockers_v2:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
@@ -361,7 +361,7 @@ dockers_v2:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
@@ -383,7 +383,7 @@ dockers_v2:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
@@ -405,7 +405,7 @@ dockers_v2:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:

View File

@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
Username: &username,
})
if err != nil {
return "", fmt.Errorf("switch profile failed: %v", err)
return "", fmt.Errorf("switch profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil

View File

@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
return err
}
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
if err != nil {
return fmt.Errorf("add profile request: %w", err)
return err
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
}
return err
}
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
// and returns the new profile's ID. It is the single entry point for profile
// creation, shared by `netbird profile add` and the `netbird up --profile
// <name>` auto-create path.
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
ProfileName: profileName,
Username: username,
})
if err != nil {
return "", fmt.Errorf("add profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil
}

View File

@@ -11,7 +11,6 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
// Resolve the active profile's display name via the daemon, which runs
// as root and can read the per-user profile files. The local profile
// manager only knows the active profile ID, not its display name.
profName := getActiveProfileName(ctx)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
return resp, nil
}
// getActiveProfileName asks the daemon for the active profile's display
// name. The daemon runs as root and can read the per-user profile files to
// resolve the ID to its human-readable name. Returns an empty string on any
// error so status output degrades gracefully.
func getActiveProfileName(ctx context.Context) string {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return ""
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
if err != nil {
return ""
}
return resp.GetProfileName()
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected":

View File

@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true
}
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
}
// switchOrCreateProfile switches the active profile to the one identified by
// handle, creating it first when it does not exist yet. This restores the
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
// missing profile instead of failing.
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
st, ok := gstatus.FromError(err)
if !ok || st.Code() != codes.NotFound {
return err
}
// Don't fail immediately on a create error: a concurrent run may
// have created the profile between the NotFound above and this
// call, in which case the retried switch still succeeds. Only
// surface the create error if the switch also fails.
_, createErr := createProfile(ctx, handle, username)
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
if createErr != nil {
return fmt.Errorf("create profile: %w", createErr)
}
return err
}
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return err
}
return nil
}
// createProfile dials the daemon and creates a new profile with the given
// display name, returning its generated ID. Use addProfileOnDaemon directly
// when a daemon client is already available to reuse the connection.
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
// override the default profile filepath if provided
if configPath != "" {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
@@ -38,11 +39,15 @@ const (
// defaultWarningDelayBase is the starting grace window before a
// "Nameserver group unreachable" event fires for a group that's
// never been healthy and only has overlay upstreams with no
// Connected peer. Per-server and overridable; see warningDelayFor.
defaultWarningDelayBase = 30 * time.Second
// Connected peer. Per-server and overridable via envWarningDelay;
// see warningDelay.
defaultWarningDelayBase = 60 * time.Second
// warningDelayBonusCap caps the route-count bonus added to the
// base grace window. See warningDelayFor.
// base grace window. See warningDelay.
warningDelayBonusCap = 30 * time.Second
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
)
// errNoUsableNameservers signals that a merged-domain group has no usable
@@ -135,7 +140,7 @@ type DefaultServer struct {
disableSys bool
mux sync.Mutex
service service
dnsMuxMap registeredHandlerMap
dnsMuxHandlers []handlerWrapper
localResolver *local.Resolver
wgInterface WGIface
hostManager hostManager
@@ -199,8 +204,6 @@ type handlerWrapper struct {
priority int
}
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
@@ -289,7 +292,6 @@ func newDefaultServer(
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(),
wgInterface: wgInterface,
statusRecorder: statusRecorder,
@@ -298,7 +300,7 @@ func newDefaultServer(
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
warningDelayBase: defaultWarningDelayBase,
warningDelayBase: warningDelayBaseFromEnv(),
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
type routeSettable interface {
setSelectedRoutes(func() route.HAMap)
}
for _, entry := range s.dnsMuxMap {
for _, entry := range s.dnsMuxHandlers {
if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected)
}
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
for _, existing := range s.dnsMuxHandlers {
s.deregisterHandler([]string{existing.domain}, existing.priority)
existing.handler.Stop()
// The local resolver is a persistent singleton shared by every custom
// zone and reused across config updates. Its chain registrations are
// per-config and must be deregistered, but Stop() cancels its lookup
// context (breaking external CNAME-target resolution) and clears its
// records, so it must not be torn down here.
if existing.handler != s.localResolver {
existing.handler.Stop()
}
}
muxUpdateMap := make(registeredHandlerMap)
for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update
}
s.dnsMuxMap = muxUpdateMap
s.dnsMuxHandlers = muxUpdates
}
// updateNSGroupStates records the new group set and pokes the refresher.
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
return false
}
// warningDelayBaseFromEnv returns the base grace window, honoring
// envWarningDelay when it holds a valid positive Go duration. Invalid or
// non-positive values fall back to defaultWarningDelayBase.
func warningDelayBaseFromEnv() time.Duration {
val := os.Getenv(envWarningDelay)
if val == "" {
return defaultWarningDelayBase
}
d, err := time.ParseDuration(val)
if err != nil {
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
return defaultWarningDelayBase
}
if d <= 0 {
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
return defaultWarningDelayBase
}
return d
}
// warningDelay returns the grace window for the given selected-route
// count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
// in more than one handler.
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
merged := make(map[netip.AddrPort]UpstreamHealth)
for _, entry := range s.dnsMuxMap {
for _, entry := range s.dnsMuxHandlers {
reporter, ok := entry.handler.(upstreamHealthReporter)
if !ok {
continue

View File

@@ -104,19 +104,6 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
@@ -132,22 +119,20 @@ func TestUpdateDNSServer(t *testing.T) {
},
}
dummyHandler := local.NewResolver()
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registeredHandlerMap
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -169,20 +154,17 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
dummyHandler.ID(): handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
@@ -191,10 +173,10 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
@@ -215,15 +197,13 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
"local-resolver": handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
},
@@ -232,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
@@ -240,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -262,7 +242,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -284,7 +264,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -301,42 +281,41 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap),
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap),
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
@@ -393,7 +372,7 @@ func TestUpdateDNSServer(t *testing.T) {
}
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
@@ -405,14 +384,20 @@ func TestUpdateDNSServer(t *testing.T) {
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
@@ -512,8 +497,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
}()
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
@@ -1029,15 +1014,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := registeredHandlerMap{
"upstream-group1": {
baseMatchHandlers := []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
"upstream-group2": {
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
@@ -1046,15 +1031,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseRootHandlers := registeredHandlerMap{
"upstream-root1": {
baseRootHandlers := []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
"upstream-root2": {
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
@@ -1063,22 +1048,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseMixedHandlers := registeredHandlerMap{
"upstream-group1": {
baseMixedHandlers := []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
"upstream-group2": {
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
},
"upstream-other": {
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
@@ -1089,7 +1074,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
tests := []struct {
name string
initialHandlers registeredHandlerMap
initialHandlers []handlerWrapper
updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain
description string
@@ -1373,32 +1358,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{
dnsMuxMap: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
dnsMuxHandlers: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Perform the update
server.updateMux(tt.updates)
// Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
"Number of handlers after update doesn't match expected")
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
var found *handlerWrapper
for i := range server.dnsMuxHandlers {
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
found = &server.dnsMuxHandlers[i]
break
}
}
assert.NotNil(t, found, "Expected handler %s not found", id)
if found != nil {
assert.Equal(t, expectedDomain, found.domain,
"Domain mismatch for handler %s", id)
}
}
// Verify no unexpected handlers exist
for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
for _, entry := range server.dnsMuxHandlers {
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
}
// Verify the handlerChain state and order
@@ -1413,7 +1404,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Verify handler exists in mux
foundInMux := false
for _, muxEntry := range server.dnsMuxMap {
for _, muxEntry := range server.dnsMuxHandlers {
if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
@@ -1422,12 +1413,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}
}
assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxMap")
"Handler in chain not found in dnsMuxHandlers")
}
})
}
}
// chainHasPattern reports whether the handler chain holds an entry registered
// for the given fqdn pattern at the given priority.
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
for _, h := range s.handlerChain.handlers {
if h.OrigPattern == pattern && h.Priority == priority {
return true
}
}
return false
}
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
// tracks each (handler, domain) registration independently when one handler
// serves multiple zones. Every custom zone is served by the same handler
// instance (the local resolver, whose ID is the constant "local-resolver"), so
// removing one zone must deregister exactly that zone's chain entry and leave
// the others in place. Tracking registrations by handler ID alone collapses all
// zones onto one entry, leaving removed zones in the chain to answer
// authoritatively with no records.
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
// One handler serves every custom zone, mirroring s.localResolver.
shared := &mockHandler{Id: "local-resolver"}
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Two custom zones under the same handler. The surviving zone is registered
// last, mirroring the management emission order.
server.updateMux([]handlerWrapper{
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test should be registered after the first update")
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should be registered after the first update")
// Remove one zone, keep the other.
server.updateMux([]handlerWrapper{
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should remain after removing userzone.test")
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test handler must be deregistered, not leaked in the chain")
}
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
// does not tear down the shared local resolver during reconfiguration. The
// resolver is a process-lifetime singleton reused across config updates;
// Stop() cancels its lookup context (breaking external CNAME-target
// resolution) and clears its records. updateMux must deregister its chain
// entries without stopping it. Records surviving a teardown update is the
// observable proxy: Stop() would have cleared them.
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
resolver := local.NewResolver()
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
Name: "peer.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}))
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
localResolver: resolver,
}
server.updateMux([]handlerWrapper{
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
})
// Remove the zone. The resolver must survive so its records and lookup
// context stay intact for the next registration.
server.updateMux(nil)
var response *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
response = m
return nil
},
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
require.NotNil(t, response, "local resolver should answer after teardown")
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
}
func TestExtraDomains(t *testing.T) {
tests := []struct {
name string
@@ -2049,7 +2136,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
@@ -2207,7 +2293,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
}
}
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
@@ -2283,12 +2369,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
@@ -2395,7 +2480,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
@@ -2407,7 +2491,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2444,7 +2528,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
@@ -2459,7 +2542,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2484,6 +2567,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestWarningDelayBaseFromEnv(t *testing.T) {
tests := []struct {
name string
set bool
val string
want time.Duration
}{
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envWarningDelay, tc.val)
if !tc.set {
os.Unsetenv(envWarningDelay)
}
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
})
}
}
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
@@ -2595,7 +2704,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
@@ -2613,7 +2721,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2640,7 +2748,6 @@ func TestDNSLoopPrevention(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct {

View File

@@ -443,21 +443,25 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
// A valid response means the upstream is reachable, whatever the Rcode.
u.markUpstreamOk(upstream)
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
// refused zones, transient recursion errors), not reachability
// problems: fail over for a better answer but keep the upstream healthy.
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
@@ -465,7 +469,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}

View File

@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
// those are per-question outcomes from a reachable server and must not mark
// the upstream unhealthy. Only transport failures (timeouts) do.
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.10:53")
b := netip.MustParseAddrPort("192.0.2.11:53")
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
tests := []struct {
name string
respA mockUpstreamResponse
respB mockUpstreamResponse
wantHealthy bool
}{
{
name: "both SERVFAIL are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
wantHealthy: true,
},
{
name: "both REFUSED are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
wantHealthy: true,
},
{
name: "timeout marks unhealthy",
respA: mockUpstreamResponse{err: timeoutErr},
respB: mockUpstreamResponse{err: timeoutErr},
wantHealthy: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): tc.respA,
b.String(): tc.respB,
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a, b})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, a, "primary upstream should have a health record")
if tc.wantHealthy {
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
} else {
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
}
})
}
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string

View File

@@ -64,6 +64,7 @@ import (
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -1077,17 +1078,11 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return ErrResetConnection
}
if !e.config.DisableIPv6 {
reset, err := e.reconcileIPv6(conf)
if err != nil {
log.Warnf("reconcile IPv6 from PeerConfig: %v", err)
}
if reset {
log.Infof("peer IPv6 address changed value, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
log.Infof("peer IPv6 address changed, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
if conf.GetSshConfig() != nil {
@@ -1109,58 +1104,25 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
// reconcileIPv6 applies the management-supplied IPv6 overlay address to the
// engine's WireGuard interface in place when possible. Three transitions:
//
// - First v6 assignment (current had no v6, conf carries one): apply via
// WGIface.UpdateAddr, no reset. Critical for embedded clients whose
// boot config has no v6 — without this we reset on every fresh start
// once management has v6 enabled, orphaning any netstack listeners
// held outside the engine.
// - v6 removed (current had v6, conf carries none): clear in place, no
// reset.
// - v6 swapped to a different non-empty value: returns reset=true so the
// caller falls back to the engine-recreate path — the underlying
// interface address can't be safely swapped in place across all
// backends (gVisor netstack in particular fixes its address at
// CreateNetTUN time).
//
// Mutates e.config.WgAddr to match the applied state so subsequent
// PeerConfig comparisons are stable.
func (e *Engine) reconcileIPv6(conf *mgmProto.PeerConfig) (reset bool, err error) {
raw := conf.GetAddressV6()
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
// differs from the configured address (added, removed, or changed).
// Compares against e.config.WgAddr (not the interface address, which may have
// been cleared by ClearIPv6 if OS assignment failed).
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
current := e.config.WgAddr
raw := conf.GetAddressV6()
if len(raw) == 0 {
if !current.HasIPv6() {
return false, nil
}
current.ClearIPv6()
e.config.WgAddr = current
if err := e.wgInterface.UpdateAddr(current); err != nil {
return false, fmt.Errorf("clear ipv6 on wg interface: %w", err)
}
return false, nil
return current.HasIPv6()
}
incoming := current
if err := incoming.SetIPv6FromCompact(raw); err != nil {
return false, fmt.Errorf("decode v6 overlay address: %w", err)
prefix, err := netiputil.DecodePrefix(raw)
if err != nil {
log.Errorf("decode v6 overlay address: %v", err)
return false
}
if !current.HasIPv6() {
e.config.WgAddr = incoming
if err := e.wgInterface.UpdateAddr(incoming); err != nil {
return false, fmt.Errorf("apply ipv6 on wg interface: %w", err)
}
return false, nil
}
if current.IPv6 == incoming.IPv6 && current.IPv6Net == incoming.IPv6Net {
return false, nil
}
return true, nil
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
}
func (e *Engine) receiveJobEvents() {

View File

@@ -1,305 +0,0 @@
package internal
import (
"context"
"errors"
"net/netip"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/peer"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// reconcileIPv6 / updateConfig regression suite. Locks down the behavior that
// PR #5631 (main-side IPv6 overlay support) accidentally broke for embedded
// netstack clients: any first NetworkMap update that brings an IPv6 address
// used to trigger ErrResetConnection, which destroys the netstack and orphans
// every listener bound on it (proxy-side inbound listeners in particular).
// The fix in reconcileIPv6 distinguishes "v6 first-assigned" (apply in place)
// from "v6 swapped value" (must reset).
func mustEncodeV6Prefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err, "encode v6 prefix %s", p)
return b
}
// reconcileIPv6Fixture builds the smallest Engine the function under test
// needs: a config (with WgAddr being the load-bearing field) and a wgInterface
// whose UpdateAddr call we can observe.
func reconcileIPv6Fixture(t *testing.T, initial wgaddr.Address) (*Engine, *MockWGIface, *wgaddr.Address) {
t.Helper()
var applied wgaddr.Address
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return initial },
UpdateAddrFunc: func(a wgaddr.Address) error {
applied = a
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: initial},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
return e, mock, &applied
}
func TestReconcileIPv6_FirstAssignment_AppliesInPlace(t *testing.T) {
// Embedded clients boot v4-only; management later assigns a v6 overlay.
// The fix: apply v6 in place, return reset=false. Pre-fix this case
// fell through to the "v6 changed" branch and reset the engine.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, mock, applied := reconcileIPv6Fixture(t, v4)
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "first v6 assignment must NOT request an engine reset")
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the new v6")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6, "engine config v6 address must match")
assert.Equal(t, v6Prefix.Masked(), e.config.WgAddr.IPv6Net, "engine config v6 prefix must match")
require.True(t, applied.HasIPv6(), "WGIface.UpdateAddr must be called with v6 populated")
assert.Equal(t, v6Prefix.Addr(), applied.IPv6, "UpdateAddr must carry the new v6")
_ = mock
}
func TestReconcileIPv6_NoChange_NoOp(t *testing.T) {
// Steady state: management redelivers the same PeerConfig. No interface
// mutation, no reset. Guards against an infinite reset loop if the
// comparison ever drifts (e.g. address-vs-prefix masking bugs).
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "unchanged v6 must NOT trigger reset")
assert.False(t, updateAddrCalled, "unchanged v6 must NOT call UpdateAddr")
}
func TestReconcileIPv6_Removed_AppliesInPlace(t *testing.T) {
// Management withdraws v6 (e.g. account toggled off the v6 group).
// Cleared in place, no reset.
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
e, _, applied := reconcileIPv6Fixture(t, addr)
e.config.WgAddr = addr
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: nil,
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "v6 removed must NOT trigger reset")
assert.False(t, e.config.WgAddr.HasIPv6(), "engine config must reflect v6 cleared")
assert.False(t, applied.HasIPv6(), "UpdateAddr must receive cleared v6")
}
func TestReconcileIPv6_PrefixLengthChanged_RequestsReset(t *testing.T) {
// Same v6 host, different mask (e.g. /64 → /80). Treated like a value
// change because the new netmask redefines the broadcast/scope.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::1/80")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 prefix length change must request a reset")
assert.False(t, updateAddrCalled, "v6 prefix length change must NOT touch the interface")
}
func TestReconcileIPv6_ValueChanged_RequestsReset(t *testing.T) {
// v6 was X, now Y. The netstack backend can't safely swap an existing
// address in place — fall back to the engine recreate path.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::2/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 value change must request a reset")
assert.False(t, updateAddrCalled,
"v6 value change must NOT call UpdateAddr — caller will recreate the interface")
}
func TestReconcileIPv6_InvalidBytes_ReturnsError(t *testing.T) {
// Corrupt PeerConfig.AddressV6 must not crash the engine and must not
// trigger a spurious reset.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, _, applied := reconcileIPv6Fixture(t, v4)
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: []byte{0x00}, // truncated, definitely not a valid prefix
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "malformed v6 bytes must surface an error")
assert.False(t, reset, "decode error must NOT request a reset")
assert.False(t, applied.HasIPv6(), "decode error must NOT touch the interface")
}
func TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset(t *testing.T) {
// If WGIface.UpdateAddr fails (e.g. OS-side assignment error on a
// kernel device), reconcileIPv6 returns the error to the caller for
// logging — but it must NOT request a reset. The whole point of the
// fix is to AVOID the reset cascade on v6 transitions.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return errors.New("os refused address") },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "UpdateAddr failure must surface")
assert.False(t, reset, "UpdateAddr failure must NOT request a reset")
}
func TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine(t *testing.T) {
// The integration check: updateConfig must not return ErrResetConnection
// when the only change between current state and the new PeerConfig is
// "v6 added". Pre-fix this returned ErrResetConnection, tearing down
// every listener bound on the engine's netstack.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return nil },
IsUserspaceBindFunc: func() bool { return true },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4, WgPort: 51820},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
statusRecorder: peer.NewRecorder("https://mgm.test"),
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
err := e.updateConfig(conf)
assert.NoError(t, err,
"updateConfig MUST NOT return ErrResetConnection when v6 is added for the first time — that's the bug fix")
assert.NotErrorIs(t, err, ErrResetConnection)
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the assigned v6 after updateConfig")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6)
}

View File

@@ -66,6 +66,7 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -1706,12 +1707,82 @@ func getPeers(e *Engine) int {
return len(e.peerStore.PeersPubKey())
}
// The former TestEngine_hasIPv6Changed has been superseded by
// engine_reconcileipv6_test.go — the underlying function (hasIPv6Changed)
// was replaced by reconcileIPv6, which applies "v6 added" / "v6 removed"
// in place instead of demanding a full engine reset. The behavioral
// matrix the old test enforced is now covered, with corrected expectations,
// by TestReconcileIPv6_* in that sibling file.
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err)
return b
}
func TestEngine_hasIPv6Changed(t *testing.T) {
v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6.IPv6 = netip.MustParseAddr("fd00::1")
v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
tests := []struct {
name string
current wgaddr.Address
confV6 []byte
expected bool
}{
{
name: "no v6 before, no v6 now",
current: v4Only,
confV6: nil,
expected: false,
},
{
name: "no v6 before, v6 added",
current: v4Only,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: true,
},
{
name: "had v6, now removed",
current: v4v6,
confV6: nil,
expected: true,
},
{
name: "had v6, same v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: false,
},
{
name: "had v6, different v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")),
expected: true,
},
{
name: "same v6 addr, different prefix length",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")),
expected: true,
},
{
name: "decode error keeps status quo",
current: v4Only,
confV6: []byte{1, 2, 3},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{WgAddr: tt.current},
}
conf := &mgmtProto.PeerConfig{
AddressV6: tt.confV6,
}
assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf))
})
}
}
func TestFilterAllowedIPs(t *testing.T) {
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
return srv
}
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -33,6 +33,8 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
@@ -82,6 +84,9 @@ type ProxyServiceServer struct {
// Manager for users
usersManager users.Manager
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
idpManager idp.Manager
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
@@ -157,7 +162,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -166,6 +171,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
idpManager: idpManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
@@ -1702,22 +1708,7 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
@@ -1754,6 +1745,45 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}, nil
}
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
// user or peer ID, and peer name or user email.
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
// Resolve the principal: when the peer is linked to a user, the human is the
// principal so multiple peers owned by the same user share a single
// identity. Unlinked peers (machine agents) are their own principal keyed on
// peer.ID. displayIdentity is what upstream gateways tag spend with —
// user.Email when linked, peer.Name when not.
// If the peer isn't associated with a user, return the peer info directly.
if peer.UserID == "" {
return peer.ID, peer.Name
}
// Otherwise, if the peer is linked to a user, the user is the principal and
// if an IdP is available, we gather details on the user from it.
principalID := peer.UserID
displayIdentity := peer.Name
// Stored column first (cheap, but often empty for OIDC-provisioned users).
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
// IdP enrichment wins when available — the stored email column is a
// best-effort cache and is frequently empty for OIDC users. Enrichment
// failures must never fail the RPC; we simply keep the stored/peer identity.
if s.idpManager != nil {
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
displayIdentity = ud.Email
} else if uerr != nil {
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
}
}
return principalID, displayIdentity
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security

View File

@@ -3,14 +3,19 @@ package grpc
import (
"context"
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockReverseProxyManager struct {
@@ -137,6 +142,52 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
return user, nil, nil
}
// mockTunnelPeersManager implements only the two peers.Manager methods that
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
// panics if any unexpected method is invoked).
type mockTunnelPeersManager struct {
peers.Manager
peer *peer.Peer
peerErr error
groups []*types.Group
groupsErr error
}
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
return m.peer, m.peerErr
}
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
return m.peer, m.groups, m.groupsErr
}
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
// an IdP that knows nothing about the user.
type mockTunnelIdpManager struct {
idp.Manager
email string
hasData bool
err error
gotCalls int
gotMeta []idp.AppMetadata
}
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
m.gotCalls++
m.gotMeta = append(m.gotMeta, meta)
if m.err != nil {
return nil, m.err
}
if !m.hasData {
// This might not be a thing any of the actual IDP implementations do,
// i.e. return a nil value with no error, but it seems valuable to test
// that behavior here.
return nil, nil //nolint:nilnil
}
return &idp.UserData{ID: userID, Email: m.email}, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -354,6 +405,163 @@ func TestValidateUserGroupAccess(t *testing.T) {
}
}
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
// (IdP email -> stored User.Email -> peer.Name).
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
const (
domain = "app.example.com"
accountID = "account1"
peerID = "peer1"
peerName = "peer-display-name"
userID = "user1"
)
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
tests := []struct {
name string
peerUserID string
storedUsers map[string]*types.User
storedErr error
noIdP bool
idpEmail string
idpHasData bool
idpErr error
expectEmail string
expectUserID string
expectIdPHit bool
}{
{
name: "idp email wins over stored email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp returns empty email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "",
idpHasData: true,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp has no data",
peerUserID: userID,
storedUsers: storedUser,
idpHasData: false,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp errors",
peerUserID: userID,
storedUsers: storedUser,
idpErr: errors.New("idp unreachable"),
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when no idp manager",
peerUserID: userID,
storedUsers: storedUser,
noIdP: true,
expectEmail: "stored@example.com",
expectUserID: userID,
},
{
name: "idp email when stored email is empty",
peerUserID: userID,
storedUsers: storedUserNoEmail,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "idp email when stored user missing keeps peer.UserID as principal",
peerUserID: userID,
storedUsers: map[string]*types.User{},
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "unlinked peer uses peer name and never consults idp",
peerUserID: "",
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: peerName,
expectUserID: peerID,
expectIdPHit: false,
},
{
name: "linked peer with empty stored email and no idp falls back to peer name",
peerUserID: userID,
storedUsers: storedUserNoEmail,
noIdP: true,
expectEmail: peerName,
expectUserID: userID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := &service.Service{Domain: domain, AccountID: accountID}
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
},
peersManager: &mockTunnelPeersManager{
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
},
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
}
var idpMock *mockTunnelIdpManager
if !tt.noIdP {
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
server.idpManager = idpMock
}
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
Domain: domain,
TunnelIp: "100.64.0.1",
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.True(t, resp.GetValid(), "expected access granted")
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
assert.Equal(t, tt.expectUserID, resp.GetUserId())
if idpMock != nil {
if tt.expectIdPHit {
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
require.Len(t, idpMock.gotMeta, 1)
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
} else {
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
}
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)

View File

@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err

View File

@@ -217,6 +217,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager,
nil,
nil,
nil,
)
proxyService.SetServiceManager(&testServiceManager{store: testStore})

View File

@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {

View File

@@ -982,8 +982,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var peer *nbpeer.Peer
var updated, versionChanged, ipv6CapabilityChanged bool
var err error
var postureChecks []*posture.Checks
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -1011,13 +1009,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return status.NewPeerLoginExpiredError()
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
@@ -1025,11 +1018,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err
}
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
}
return nil
})
@@ -1037,6 +1025,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
@@ -1047,9 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1124,7 +1117,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
var peer *nbpeer.Peer
var shouldStorePeer bool
var shouldStorePeer, shouldUpdatePeers bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1151,14 +1144,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
shouldUpdatePeers = true
}
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
@@ -1180,7 +1169,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
peer.UpdateMetaIfNew(ctx, login.Meta)
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, false, err
}
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, false, err
}
@@ -1190,7 +1187,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
if isStatusChanged || shouldStorePeer {
if shouldUpdatePeers {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1286,12 +1283,22 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, nil, false, nil
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
if err != nil {
return nil, nil, false, err
}
@@ -1299,32 +1306,16 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, postureChecks, enableSSH, nil
}
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
for _, peerID := range peerGroupIDs {
groupIDsMap[peerID] = struct{}{}
}
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return false, err
}
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
for _, g := range peerGroups {
peerGroupIDs[g.ID] = struct{}{}
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
}
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
if len(policies) == 0 {
return nil, nil
}
@@ -1336,11 +1327,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
continue
}
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
if err != nil {
return nil, err
}
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
}
@@ -1353,29 +1340,19 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
}
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil {
return nil, err
}
for _, sourceGroup := range rule.Sources {
group, ok := sourceGroups[sourceGroup]
if !ok {
return nil, fmt.Errorf("failed to check peer in policy source group")
}
if slices.Contains(group.Peers, peerID) {
return policy.SourcePostureChecks, nil
if slices.Contains(peerGroupIDs, sourceGroup) {
return policy.SourcePostureChecks
}
}
}
return nil, nil
return nil
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO

View File

@@ -1,12 +1,16 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -162,49 +166,7 @@ type PeerSystemMeta struct { //nolint:revive
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
})
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
})
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
sort.Slice(p.Files, func(i, j int) bool {
return p.Files[i].Path < p.Files[j].Path
})
sort.Slice(other.Files, func(i, j int) bool {
return other.Files[i].Path < other.Files[j].Path
})
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.KernelVersion == other.KernelVersion &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.OSVersion == other.OSVersion &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion &&
p.SystemSerialNumber == other.SystemSerialNumber &&
p.SystemProductName == other.SystemProductName &&
p.SystemManufacturer == other.SystemManufacturer &&
p.Environment.Cloud == other.Environment.Cloud &&
p.Environment.Platform == other.Environment.Platform &&
p.Flags.isEqual(other.Flags) &&
capabilitiesEqual(p.Capabilities, other.Capabilities)
return len(metaDiff(p, other)) == 0
}
func (p PeerSystemMeta) isEmpty() bool {
@@ -296,7 +258,7 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
if meta.isEmpty() {
return updated, versionChanged
}
@@ -308,14 +270,121 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged boo
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return updated, versionChanged
oldVersion := p.Meta.WtVersion
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
p.Meta = meta
updated = true
}
p.Meta = meta
updated = true
versionInfo := ""
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return updated, versionChanged
}
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
// source of truth for meta comparison: isEqual reports equality as an empty
// diff, so the log line can never disagree with the change decision. Slices are
// cloned before sorting, so callers' meta is not mutated.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
var diff []string
add := func(field string, oldVal, newVal any) {
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return diff
}
// sameMultiset reports whether two slices contain the same elements with the
// same multiplicity, ignoring order. The element type is the comparison key, so
// every field participates in equality.
func sameMultiset[T comparable](a, b []T) bool {
if len(a) != len(b) {
return false
}
counts := make(map[T]int, len(a))
for _, v := range a {
counts[v]++
}
for _, v := range b {
counts[v]--
if counts[v] == 0 {
delete(counts, v)
}
}
return len(counts) == 0
}
// GetLastLogin returns the last login time of the peer.
func (p *Peer) GetLastLogin() time.Time {
if p.LastLogin != nil {

View File

@@ -0,0 +1,113 @@
package peer
import (
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
// map 1:1 to a single diff entry. Today the only such field is Environment, which
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
// entry beyond its single struct field. If you teach metaDiff to explode another
// field into N entries, bump this by N-1; if you collapse a field, lower it.
const metaDiffExtraEntries = 1
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
// values and diffs it against the zero value, then asserts metaDiff emits exactly
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
//
// The expected count is derived from the struct via reflection, so adding a field
// to PeerSystemMeta raises the expectation automatically — but the actual diff
// only grows if metaDiff was taught to compare the new field. A mismatch means
// someone changed the struct without updating metaDiff (or this test's
// extra-entry accounting), which is exactly what we want to catch.
func TestMetaDiff_CoversAllFields(t *testing.T) {
var full PeerSystemMeta
exported := populateAll(t, reflect.ValueOf(&full).Elem())
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
diff := metaDiff(PeerSystemMeta{}, full)
require.Len(t, diff, exported+metaDiffExtraEntries,
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
"diff was: %v", diff)
require.False(t, full.isEqual(PeerSystemMeta{}),
"isEqual must report a fully-populated meta as different from the zero value")
}
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
// compare would not change the diff count. This flips each Flags field on its own
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
// fails here.
func TestFlags_isEqualChecksEveryField(t *testing.T) {
typ := reflect.TypeOf(Flags{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
require.Equal(t, reflect.Bool, f.Type.Kind(),
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
var a, b Flags
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
}
}
// populateAll sets every exported field of the struct to a deterministic non-zero
// value, recursing into nested structs and the element type of struct slices so
// that each leaf differs from zero. It returns the number of exported fields on
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
// settable exported fields and is comparable with ==).
func populateAll(t *testing.T, v reflect.Value) int {
t.Helper()
typ := v.Type()
exported := 0
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath != "" { // unexported
continue
}
exported++
setNonZero(t, v.Field(i))
}
return exported
}
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
// recursing into nested structs and populating one element of slice fields.
func setNonZero(t *testing.T, field reflect.Value) {
t.Helper()
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
return
}
switch field.Kind() {
case reflect.String:
field.SetString("non-zero")
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Float32, reflect.Float64:
field.SetFloat(7)
case reflect.Struct:
populateAll(t, field)
case reflect.Slice:
s := reflect.MakeSlice(field.Type(), 1, 1)
setNonZero(t, s.Index(0))
field.Set(s)
default:
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
}
}

View File

@@ -466,20 +466,15 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.
_ = ln.Close()
}()
var backoff nbtcp.AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
continue
}
backoff.Reset()
router.HandleConn(ctx, conn)
}
}

View File

@@ -533,125 +533,3 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
// scriptedAcceptListener returns pre-scripted errors from Accept(). Used
// to drive the feedRouterFromListener tests without binding a real
// socket — the production code path is a netstack-backed listener that
// returns gVisor's "endpoint is in invalid state" forever after its
// endpoint is destroyed.
type scriptedAcceptListener struct {
errs chan error
closed chan struct{}
}
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
s := &scriptedAcceptListener{
errs: make(chan error, len(errs)+1),
closed: make(chan struct{}),
}
for _, e := range errs {
s.errs <- e
}
return s
}
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
select {
case <-s.closed:
return nil, net.ErrClosed
case err := <-s.errs:
return nil, err
}
}
func (s *scriptedAcceptListener) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
func (s *scriptedAcceptListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
// errSentinel carries a literal error message so tests can synthesise
// the exact gVisor text without importing the netstack package.
type errSentinel string
func (e errSentinel) Error() string { return string(e) }
// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint is the
// regression guard for the inbound side of the tight-loop bug. The
// per-account plain-HTTP feeder must recognise gVisor's "endpoint is in
// invalid state" and exit, otherwise it pegs a CPU core and floods the
// account-scoped log with the same accept error every iteration.
func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
router := nbtcp.NewRouter(logger, nil, addr)
gvisorErr := &net.OpError{
Op: "accept",
Net: "tcp",
Addr: addr,
Err: errSentinel("endpoint is in invalid state"),
}
ln := newScriptedAcceptListener(gvisorErr)
defer ln.Close()
done := make(chan struct{})
go func() {
defer close(done)
feedRouterFromListener(context.Background(), ln, router, logger, "acct-1")
}()
select {
case <-done:
// Expected: loop recognised the gVisor error and returned.
case <-time.After(2 * time.Second):
t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
}
}
// TestFeedRouterFromListener_BacksOffOnTransientError asserts the
// defence-in-depth path: an unknown sticky Accept error must NOT cause
// CPU spin. The loop backs off and exits cleanly when ctx is cancelled.
func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
router := nbtcp.NewRouter(logger, nil, addr)
const transientCount = 5
errs := make([]error, transientCount)
for i := range errs {
errs[i] = errSentinel("transient: temporary network error")
}
ln := newScriptedAcceptListener(errs...)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
start := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
feedRouterFromListener(ctx, ln, router, logger, "acct-1")
}()
time.AfterFunc(150*time.Millisecond, cancel)
select {
case <-done:
// Expected.
case <-time.After(2 * time.Second):
t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken")
}
// Without backoff the 5 scripted errors would burn in microseconds.
// With backoff the first delay alone is 5ms, so the loop must take
// at least that long even though ctx fires at 150ms.
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond,
"loop ran without backing off — would burn CPU in production")
}

View File

@@ -356,7 +356,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
wgPort := int(n.clientCfg.WGPort)
embedOpts := embed.Options{
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
@@ -371,9 +371,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
Performance: n.clientCfg.Performance,
}
logEmbedOptions(n.logger, accountID, serviceID, publicKey.String(), embedOpts)
client, err := embed.New(embedOpts)
})
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
}
@@ -849,53 +847,3 @@ func DirectUpstreamFromContext(ctx context.Context) bool {
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
return v
}
// logEmbedOptions emits a single structured INFO line summarising every
// operationally meaningful flag handed to embed.New for this per-account
// client. Secrets (PrivateKey, PreSharedKey) are reduced to a "present"
// boolean — never logged verbatim. Use this when an embedded peer
// silently misbehaves: most failure modes (inbound drops, wrong
// management URL, v6 unexpectedly on, userspace flipped, port clash)
// are obvious from these flags before any traffic flows.
func logEmbedOptions(logger *log.Logger, accountID types.AccountID, serviceID types.ServiceID, publicKey string, opts embed.Options) {
wgPort := 0
if opts.WireguardPort != nil {
wgPort = *opts.WireguardPort
}
mtu := uint16(0)
if opts.MTU != nil {
mtu = *opts.MTU
}
perfBuffers := uint32(0)
if opts.Performance.PreallocatedBuffersPerPool != nil {
perfBuffers = *opts.Performance.PreallocatedBuffersPerPool
}
perfBatch := uint32(0)
if opts.Performance.MaxBatchSize != nil {
perfBatch = *opts.Performance.MaxBatchSize
}
logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
"public_key": publicKey,
"device_name": opts.DeviceName,
"management_url": opts.ManagementURL,
"log_level": opts.LogLevel,
"wg_port": wgPort,
"mtu": mtu,
"block_inbound": opts.BlockInbound,
"block_lan_access": opts.BlockLANAccess,
"disable_ipv6": opts.DisableIPv6,
"disable_client_routes": opts.DisableClientRoutes,
"no_userspace": opts.NoUserspace,
"config_path_set": opts.ConfigPath != "",
"state_path_set": opts.StatePath != "",
"private_key_present": opts.PrivateKey != "",
"presharedkey_present": opts.PreSharedKey != "",
"setup_key_present": opts.SetupKey != "",
"jwt_token_present": opts.JWTToken != "",
"dns_labels": opts.DNSLabels,
"perf_buffers_per_pool": perfBuffers,
"perf_max_batch_size": perfBatch,
}).Info("starting embedded netbird client for account")
}

View File

@@ -1,85 +0,0 @@
package tcp
import (
"context"
"errors"
"net"
"strings"
"time"
)
// gvisorInvalidEndpointMsg is the canonical text gVisor netstack returns
// when Accept() is called on a listener whose underlying endpoint has
// been destroyed (peer rekey, embedded-client reset, account churn).
// There is no exported sentinel from gvisor.dev/gvisor/pkg/tcpip that
// survives gonet's *net.OpError wrapping in a way errors.Is can match,
// so we fall back to a string check. Stable across the gVisor versions
// netbird pins.
const gvisorInvalidEndpointMsg = "endpoint is in invalid state"
// IsClosedListenerErr reports whether err signals that an accept loop
// should exit because the underlying listener can no longer serve
// connections. It recognises:
//
// - net.ErrClosed for stdlib listeners (Listener.Close was called).
// - gVisor's "endpoint is in invalid state" for netstack-backed
// listeners whose endpoint was destroyed out from under them
// (typically when a per-account WireGuard netstack is reset without
// also tearing the listener entry down).
//
// Without the gVisor branch an accept loop on a netstack listener spins
// CPU-hot forever after the endpoint dies, because Accept never blocks
// again and the error neither matches net.ErrClosed nor cancels ctx.
func IsClosedListenerErr(err error) bool {
if err == nil {
return false
}
if errors.Is(err, net.ErrClosed) {
return true
}
return strings.Contains(err.Error(), gvisorInvalidEndpointMsg)
}
// AcceptBackoff implements the exponential backoff used by
// net/http.Server.Serve for transient Accept errors. Without it a loop
// hitting a sticky unknown error burns a full CPU core. The zero value
// is ready to use; call Reset after a successful Accept.
type AcceptBackoff struct {
delay time.Duration
}
// minAcceptDelay / maxAcceptDelay mirror the stdlib defaults
// (net/http.Server.Serve) and keep us well below 1 log line per second
// per orphaned listener.
const (
minAcceptDelay = 5 * time.Millisecond
maxAcceptDelay = time.Second
)
// Backoff waits the next exponential delay (5ms doubling up to 1s) and
// returns true when the wait completed. Returns false if ctx fired
// during the wait — callers should treat that as "exit the loop".
func (b *AcceptBackoff) Backoff(ctx context.Context) bool {
b.advance()
select {
case <-ctx.Done():
return false
case <-time.After(b.delay):
return true
}
}
// Reset clears the accumulated delay so the next failure starts at the
// minimum delay again. Call after a successful Accept.
func (b *AcceptBackoff) Reset() { b.delay = 0 }
func (b *AcceptBackoff) advance() {
if b.delay == 0 {
b.delay = minAcceptDelay
} else {
b.delay *= 2
}
if b.delay > maxAcceptDelay {
b.delay = maxAcceptDelay
}
}

View File

@@ -1,142 +0,0 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIsClosedListenerErr_NetErrClosed verifies the stdlib path: a
// closed *net.Listener returns net.ErrClosed wrapped in *net.OpError,
// and IsClosedListenerErr must unwrap it.
func TestIsClosedListenerErr_NetErrClosed(t *testing.T) {
wrapped := &net.OpError{Op: "accept", Net: "tcp", Err: net.ErrClosed}
assert.True(t, IsClosedListenerErr(wrapped),
"net.OpError wrapping net.ErrClosed must be recognised as closed")
}
// TestIsClosedListenerErr_GVisorInvalidEndpoint is the load-bearing
// regression guard. A gVisor netstack listener whose endpoint has been
// destroyed returns this exact text. Without recognising it the accept
// loop spins forever and burns a CPU core.
func TestIsClosedListenerErr_GVisorInvalidEndpoint(t *testing.T) {
err := fmt.Errorf("accept tcp 10.10.1.254:80: endpoint is in invalid state")
assert.True(t, IsClosedListenerErr(err),
"gVisor 'endpoint is in invalid state' must be recognised as closed")
}
// TestIsClosedListenerErr_OtherError confirms we don't over-match —
// transient errors must keep returning false so the backoff path runs.
func TestIsClosedListenerErr_OtherError(t *testing.T) {
cases := []error{
errors.New("temporary failure"),
errors.New("accept tcp 10.10.1.254:80: too many open files"),
nil,
}
for _, c := range cases {
assert.False(t, IsClosedListenerErr(c),
"unexpected match on %v — must not be treated as closed", c)
}
}
// TestAcceptBackoff_ProgressionAndCap asserts the doubling schedule:
// 5ms, 10ms, 20ms, 40ms, ... capped at 1s. The test runs against a
// real timer but uses tight bounds so a slow CI machine still passes.
func TestAcceptBackoff_ProgressionAndCap(t *testing.T) {
var b AcceptBackoff
expected := []time.Duration{
5 * time.Millisecond,
10 * time.Millisecond,
20 * time.Millisecond,
40 * time.Millisecond,
}
for i, want := range expected {
start := time.Now()
ok := b.Backoff(context.Background())
elapsed := time.Since(start)
require.True(t, ok, "Backoff %d must complete; ctx is alive", i)
assert.GreaterOrEqual(t, elapsed, want,
"backoff %d (%v) must wait at least the configured delay", i, want)
assert.Less(t, elapsed, want*4,
"backoff %d (%v) must not overshoot by more than 4x — caps misbehaving", i, want)
}
// Burn enough rounds to reach the cap, then assert subsequent
// rounds stay at exactly maxAcceptDelay (1s) — the timer should
// never exceed it.
for range 6 {
b.Backoff(context.Background())
}
assert.Equal(t, maxAcceptDelay, b.delay,
"after enough doublings the delay must clamp to maxAcceptDelay")
}
// TestAcceptBackoff_Reset confirms that a successful Accept resets the
// schedule — a busy-then-quiet listener mustn't stay on a 1s timer
// after recovery.
func TestAcceptBackoff_Reset(t *testing.T) {
var b AcceptBackoff
for range 5 {
b.Backoff(context.Background())
}
require.NotEqual(t, time.Duration(0), b.delay, "precondition: delay must have accumulated")
b.Reset()
assert.Equal(t, time.Duration(0), b.delay, "Reset must zero the delay")
start := time.Now()
ok := b.Backoff(context.Background())
elapsed := time.Since(start)
require.True(t, ok, "Backoff after Reset must complete")
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
"after Reset the next backoff must restart at minAcceptDelay")
assert.Less(t, elapsed, 50*time.Millisecond,
"after Reset the next backoff must NOT carry over the prior delay")
}
// TestAcceptBackoff_CancelDuringWait proves the loop exits promptly
// when ctx fires mid-wait. Without this, a tear-down would still take
// up to 1 second per orphaned listener.
func TestAcceptBackoff_CancelDuringWait(t *testing.T) {
var b AcceptBackoff
// Drive the backoff up so the next call will wait ~1s — long
// enough that we can detect early cancellation.
for range 10 {
b.Backoff(context.Background())
}
require.Equal(t, maxAcceptDelay, b.delay)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
start := time.Now()
ok := b.Backoff(ctx)
elapsed := time.Since(start)
assert.False(t, ok, "Backoff must return false when ctx is cancelled mid-wait")
assert.Less(t, elapsed, 200*time.Millisecond,
"cancellation must short-circuit the timer; took %v", elapsed)
}
// TestAcceptBackoff_CancelBeforeCall — when ctx is already done the
// loop exits without sleeping at all.
func TestAcceptBackoff_CancelBeforeCall(t *testing.T) {
var b AcceptBackoff
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
ok := b.Backoff(ctx)
elapsed := time.Since(start)
assert.False(t, ok, "Backoff must return false when ctx is already cancelled")
assert.Less(t, elapsed, 50*time.Millisecond,
"already-cancelled ctx must return immediately; took %v", elapsed)
}

View File

@@ -297,23 +297,18 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
}
}()
var backoff AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || IsClosedListenerErr(err) {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
if ok := r.Drain(DefaultDrainTimeout); !ok {
r.logger.Warn("timed out waiting for connections to drain")
}
return nil
}
r.logger.Debugf("SNI router accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
return nil
}
r.logger.Debugf("SNI router accept: %v", err)
continue
}
backoff.Reset()
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
r.activeConns.Add(1)
go func() {

View File

@@ -1836,132 +1836,3 @@ func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
t.Fatal("TLS conn never reached the TLS channel")
}
}
// scriptedAcceptListener is a net.Listener whose Accept() returns
// pre-scripted errors. Used by the accept-loop exit tests to simulate
// the failure mode that triggers the tight-loop bug: a netstack
// listener whose endpoint has been destroyed and now returns the gVisor
// "endpoint is in invalid state" error from every Accept call.
type scriptedAcceptListener struct {
errs chan error
closed chan struct{}
}
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
s := &scriptedAcceptListener{
errs: make(chan error, len(errs)+1),
closed: make(chan struct{}),
}
for _, e := range errs {
s.errs <- e
}
return s
}
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
select {
case <-s.closed:
return nil, net.ErrClosed
case err := <-s.errs:
return nil, err
}
}
func (s *scriptedAcceptListener) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
func (s *scriptedAcceptListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
// TestRouter_Serve_ExitsOnGVisorInvalidEndpoint is the regression guard
// for the tight-loop bug: when the underlying netstack endpoint is
// destroyed, Accept returns "endpoint is in invalid state" forever. The
// loop must recognise that signal and return, otherwise it pegs a CPU
// core and floods logs.
func TestRouter_Serve_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
gvisorErr := &net.OpError{
Op: "accept",
Net: "tcp",
Addr: addr,
Err: errSentinel("endpoint is in invalid state"),
}
ln := newScriptedAcceptListener(gvisorErr)
defer ln.Close()
done := make(chan error, 1)
go func() {
done <- router.Serve(context.Background(), ln)
}()
select {
case err := <-done:
assert.NoError(t, err, "Serve must return cleanly on a recognised closed-listener error")
case <-time.After(2 * time.Second):
t.Fatal("Serve did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
}
}
// TestRouter_Serve_BacksOffOnTransientError verifies the defence-in-
// depth path: when Accept returns an unknown transient error, the loop
// MUST not spin. It backs off, then exits cleanly once ctx is cancelled.
// "Bounded call count" stands in for "no CPU spin" — without backoff
// the goroutine would issue thousands of Accept calls in this window.
func TestRouter_Serve_BacksOffOnTransientError(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
const transientErrCount = 5
errs := make([]error, transientErrCount)
for i := range errs {
errs[i] = errSentinel("transient: too many open files")
}
ln := newScriptedAcceptListener(errs...)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
start := time.Now()
go func() {
done <- router.Serve(ctx, ln)
}()
// Cancel after enough time for the backoff to climb (5ms + 10ms +
// 20ms + 40ms = 75ms minimum), but short enough that a spinning
// loop would have made thousands of calls by now.
time.AfterFunc(150*time.Millisecond, cancel)
select {
case err := <-done:
assert.NoError(t, err, "Serve must return cleanly on ctx cancellation")
case <-time.After(2 * time.Second):
t.Fatal("Serve did not exit on ctx cancellation — backoff or exit path broken")
}
// Without backoff the loop would burn through all 5 scripted errors
// in microseconds and then block on the channel. With backoff the
// total wall time should be at least 5ms (the first backoff).
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
"loop ran without backing off — would burn CPU in production")
}
// errSentinel mirrors gVisor's tcpip error message exactly. We can't
// import the gVisor package without dragging in the whole netstack, so
// the test uses the canonical string the production error formatter
// emits — same shape IsClosedListenerErr matches in production.
type errSentinel string
func (e errSentinel) Error() string { return string(e) }

View File

@@ -125,6 +125,7 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
oidcConfig,
nil,
usersManager,
nil,
realProxyManager,
nil,
)

View File

@@ -140,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
oidcConfig,
nil,
usersManager,
nil,
proxyManager,
nil,
)

View File

@@ -21,7 +21,8 @@ AWK_FIRST_FIELD='{print $1}'
fetch_all_tags() {
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \
sort -u -V
return 0

View File

@@ -32,7 +32,8 @@ fetch_current_ports_version() {
fetch_all_tags() {
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \
sort -u -V
return 0

View File

@@ -26,6 +26,10 @@ type Peer struct {
// a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
// the same ServerStream, and a peer can be the target of many senders at
// once.
sendMu sync.Mutex
// registration time
RegisteredAt time.Time
@@ -33,6 +37,13 @@ type Peer struct {
Cancel context.CancelFunc
}
// Send writes a message to the peer's stream, serializing concurrent senders.
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.Stream.Send(msg)
}
// NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
return &Peer{

View File

@@ -0,0 +1,67 @@
package server
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/peer"
)
// concurrencyCheckStream records the maximum number of Send calls in flight at
// once. gRPC forbids concurrent SendMsg on the same ServerStream, so a correct
// server must never have more than one in flight per peer.
type concurrencyCheckStream struct {
proto.SignalExchange_ConnectStreamServer
ctx context.Context
inflight atomic.Int32
maxSeen atomic.Int32
}
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
n := s.inflight.Add(1)
for {
old := s.maxSeen.Load()
if n <= old || s.maxSeen.CompareAndSwap(old, n) {
break
}
}
// Widen the window so overlapping callers are reliably observed.
time.Sleep(time.Millisecond)
s.inflight.Add(-1)
return nil
}
func (s *concurrencyCheckStream) Context() context.Context { return s.ctx }
// TestForwardMessageToPeerSerializesSend verifies that concurrent forwards to the
// same peer never call Stream.Send concurrently, which would violate the gRPC
// ServerStream contract.
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
s, err := NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
const peerID = "peerX"
stream := &concurrencyCheckStream{ctx: context.Background()}
_, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
require.NoError(t, s.registry.Register(peer.NewPeer(peerID, stream, cancel)))
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s.forwardMessageToPeer(context.Background(), &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID})
}()
}
wg.Wait()
require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
}

View File

@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
sendResultChan := make(chan error, 1)
go func() {
select {
case sendResultChan <- dstPeer.Stream.Send(msg):
case sendResultChan <- dstPeer.Send(msg):
return
case <-dstPeer.Stream.Context().Done():
return