Compare commits

...

11 Commits

Author SHA1 Message Date
Viktor Liu
58c79f5878 [client] Fix DNS custom zone teardown: handler leak and external CNAME resolution (#6445) 2026-06-19 17:33:09 +02:00
Viktor Liu
15a0504fb1 [client] Treat answering upstreams as reachable and widen DNS health grace window (#6453) 2026-06-19 17:32:49 +02:00
Riccardo Manfrin
883a1a8961 [client] Fix profile regressions in up --profile and status (#6479)
* Restores behavior to create profile if not there on Up

* Allows to restore nerbird status showing of the profile name

* [client] Reduce upFunc cognitive complexity

Extract the profile switch/auto-create logic from upFunc into a dedicated
switchOrCreateProfile helper. The inlined NotFound-retry branch pushed
upFunc over SonarCloud's cognitive complexity threshold (S3776).
No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* [client] Make up --profile auto-create idempotent under concurrent runs

Don't fail switchOrCreateProfile on a createProfile error: a concurrent
run may create the profile between the NotFound check and our create
call. Retry the switch regardless and only surface the create error if
the switch also fails. Addresses CodeRabbit race-condition feedback.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* Share createProfile with addProfileFunc

* But allow conn reusage

* moves switchOrCreateProfile to where it's used

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-19 16:23:51 +02:00
Maycon Santos
54192a94b7 [misc] handle release candidates when fetching tags in FreeBSD port scripts (#6480)
* [misc] Exclude release candidates when fetching tags in FreeBSD port scripts
2026-06-19 14:10:43 +02:00
Pascal Fischer
8511687270 [management] log peer meta diff (#6468) 2026-06-19 13:30:52 +02:00
Pascal Fischer
35b465fa4a [management] reduce sync and login transaction (#6472) 2026-06-19 11:43:01 +02:00
Brad Ison
fb87f751a5 [management] Fetch complete user data in ValidateTunnelPeer (#6457)
* [management] Fetch complete user data in ValidateTunnelPeer

Previously the `ValidateTunnelPeer` method used by the ProxyService
would fetch user information from the database if the connected peer
was associated with a user ID, but it would not consult the IdP data
for cached info from JWT claims like email.  This caused the value of
the injected `X-Netbird-User` header to always display the peer ID and
never the user email associated with the peer as expected.

This change adds an optional IdP manager to the ProxyService and
fetches the complete user data from it if present.

* [management] Refactor ValidateTunnelPeer principal info gathering

This refactors the gathering of info on proxy tunnel peer principals
into its own method to keep the complexity down and make Sonar happy.
2026-06-19 11:39:21 +02:00
Maycon Santos
679c7182a4 [misc] Remove version prefix v docker tags (#6471) 2026-06-18 22:34:24 +02:00
Pascal Fischer
8c031ea6f0 [management] remove db calls in nested loops (#6470) 2026-06-18 22:12:59 +02:00
Pascal Fischer
60a9544656 [management] pass meta update for browser clients (#6465) 2026-06-18 17:22:42 +02:00
Viktor Liu
d3710d4bb2 [signal] Serialize concurrent sends to a peer signal stream (#6463) 2026-06-18 15:00:19 +02:00
27 changed files with 1041 additions and 284 deletions

View File

@@ -247,7 +247,7 @@ dockers_v2:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
@@ -295,7 +295,7 @@ dockers_v2:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
@@ -317,7 +317,7 @@ dockers_v2:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
@@ -339,7 +339,7 @@ dockers_v2:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
@@ -361,7 +361,7 @@ dockers_v2:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
@@ -383,7 +383,7 @@ dockers_v2:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
@@ -405,7 +405,7 @@ dockers_v2:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:

View File

@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
Username: &username,
})
if err != nil {
return "", fmt.Errorf("switch profile failed: %v", err)
return "", fmt.Errorf("switch profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil

View File

@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
return err
}
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
if err != nil {
return fmt.Errorf("add profile request: %w", err)
return err
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
}
return err
}
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
// and returns the new profile's ID. It is the single entry point for profile
// creation, shared by `netbird profile add` and the `netbird up --profile
// <name>` auto-create path.
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
ProfileName: profileName,
Username: username,
})
if err != nil {
return "", fmt.Errorf("add profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil
}

View File

@@ -11,7 +11,6 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
// Resolve the active profile's display name via the daemon, which runs
// as root and can read the per-user profile files. The local profile
// manager only knows the active profile ID, not its display name.
profName := getActiveProfileName(ctx)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
return resp, nil
}
// getActiveProfileName asks the daemon for the active profile's display
// name. The daemon runs as root and can read the per-user profile files to
// resolve the ID to its human-readable name. Returns an empty string on any
// error so status output degrades gracefully.
func getActiveProfileName(ctx context.Context) string {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return ""
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
if err != nil {
return ""
}
return resp.GetProfileName()
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected":

View File

@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true
}
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
}
// switchOrCreateProfile switches the active profile to the one identified by
// handle, creating it first when it does not exist yet. This restores the
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
// missing profile instead of failing.
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
st, ok := gstatus.FromError(err)
if !ok || st.Code() != codes.NotFound {
return err
}
// Don't fail immediately on a create error: a concurrent run may
// have created the profile between the NotFound above and this
// call, in which case the retried switch still succeeds. Only
// surface the create error if the switch also fails.
_, createErr := createProfile(ctx, handle, username)
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
if createErr != nil {
return fmt.Errorf("create profile: %w", createErr)
}
return err
}
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return err
}
return nil
}
// createProfile dials the daemon and creates a new profile with the given
// display name, returning its generated ID. Use addProfileOnDaemon directly
// when a daemon client is already available to reuse the connection.
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
// override the default profile filepath if provided
if configPath != "" {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
@@ -38,11 +39,15 @@ const (
// defaultWarningDelayBase is the starting grace window before a
// "Nameserver group unreachable" event fires for a group that's
// never been healthy and only has overlay upstreams with no
// Connected peer. Per-server and overridable; see warningDelayFor.
defaultWarningDelayBase = 30 * time.Second
// Connected peer. Per-server and overridable via envWarningDelay;
// see warningDelay.
defaultWarningDelayBase = 60 * time.Second
// warningDelayBonusCap caps the route-count bonus added to the
// base grace window. See warningDelayFor.
// base grace window. See warningDelay.
warningDelayBonusCap = 30 * time.Second
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
)
// errNoUsableNameservers signals that a merged-domain group has no usable
@@ -135,7 +140,7 @@ type DefaultServer struct {
disableSys bool
mux sync.Mutex
service service
dnsMuxMap registeredHandlerMap
dnsMuxHandlers []handlerWrapper
localResolver *local.Resolver
wgInterface WGIface
hostManager hostManager
@@ -199,8 +204,6 @@ type handlerWrapper struct {
priority int
}
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
@@ -289,7 +292,6 @@ func newDefaultServer(
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(),
wgInterface: wgInterface,
statusRecorder: statusRecorder,
@@ -298,7 +300,7 @@ func newDefaultServer(
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
warningDelayBase: defaultWarningDelayBase,
warningDelayBase: warningDelayBaseFromEnv(),
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
type routeSettable interface {
setSelectedRoutes(func() route.HAMap)
}
for _, entry := range s.dnsMuxMap {
for _, entry := range s.dnsMuxHandlers {
if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected)
}
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
for _, existing := range s.dnsMuxHandlers {
s.deregisterHandler([]string{existing.domain}, existing.priority)
existing.handler.Stop()
// The local resolver is a persistent singleton shared by every custom
// zone and reused across config updates. Its chain registrations are
// per-config and must be deregistered, but Stop() cancels its lookup
// context (breaking external CNAME-target resolution) and clears its
// records, so it must not be torn down here.
if existing.handler != s.localResolver {
existing.handler.Stop()
}
}
muxUpdateMap := make(registeredHandlerMap)
for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update
}
s.dnsMuxMap = muxUpdateMap
s.dnsMuxHandlers = muxUpdates
}
// updateNSGroupStates records the new group set and pokes the refresher.
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
return false
}
// warningDelayBaseFromEnv returns the base grace window, honoring
// envWarningDelay when it holds a valid positive Go duration. Invalid or
// non-positive values fall back to defaultWarningDelayBase.
func warningDelayBaseFromEnv() time.Duration {
val := os.Getenv(envWarningDelay)
if val == "" {
return defaultWarningDelayBase
}
d, err := time.ParseDuration(val)
if err != nil {
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
return defaultWarningDelayBase
}
if d <= 0 {
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
return defaultWarningDelayBase
}
return d
}
// warningDelay returns the grace window for the given selected-route
// count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
// in more than one handler.
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
merged := make(map[netip.AddrPort]UpstreamHealth)
for _, entry := range s.dnsMuxMap {
for _, entry := range s.dnsMuxHandlers {
reporter, ok := entry.handler.(upstreamHealthReporter)
if !ok {
continue

View File

@@ -104,19 +104,6 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
@@ -132,22 +119,20 @@ func TestUpdateDNSServer(t *testing.T) {
},
}
dummyHandler := local.NewResolver()
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registeredHandlerMap
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -169,20 +154,17 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
dummyHandler.ID(): handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
@@ -191,10 +173,10 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
@@ -215,15 +197,13 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
"local-resolver": handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
},
@@ -232,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
@@ -240,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -262,7 +242,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -284,7 +264,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -301,42 +281,41 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap),
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap),
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
@@ -393,7 +372,7 @@ func TestUpdateDNSServer(t *testing.T) {
}
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
@@ -405,14 +384,20 @@ func TestUpdateDNSServer(t *testing.T) {
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
@@ -512,8 +497,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
}()
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
@@ -1029,15 +1014,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := registeredHandlerMap{
"upstream-group1": {
baseMatchHandlers := []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
"upstream-group2": {
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
@@ -1046,15 +1031,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseRootHandlers := registeredHandlerMap{
"upstream-root1": {
baseRootHandlers := []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
"upstream-root2": {
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
@@ -1063,22 +1048,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseMixedHandlers := registeredHandlerMap{
"upstream-group1": {
baseMixedHandlers := []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
"upstream-group2": {
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
},
"upstream-other": {
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
@@ -1089,7 +1074,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
tests := []struct {
name string
initialHandlers registeredHandlerMap
initialHandlers []handlerWrapper
updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain
description string
@@ -1373,32 +1358,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{
dnsMuxMap: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
dnsMuxHandlers: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Perform the update
server.updateMux(tt.updates)
// Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
"Number of handlers after update doesn't match expected")
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
var found *handlerWrapper
for i := range server.dnsMuxHandlers {
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
found = &server.dnsMuxHandlers[i]
break
}
}
assert.NotNil(t, found, "Expected handler %s not found", id)
if found != nil {
assert.Equal(t, expectedDomain, found.domain,
"Domain mismatch for handler %s", id)
}
}
// Verify no unexpected handlers exist
for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
for _, entry := range server.dnsMuxHandlers {
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
}
// Verify the handlerChain state and order
@@ -1413,7 +1404,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Verify handler exists in mux
foundInMux := false
for _, muxEntry := range server.dnsMuxMap {
for _, muxEntry := range server.dnsMuxHandlers {
if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
@@ -1422,12 +1413,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}
}
assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxMap")
"Handler in chain not found in dnsMuxHandlers")
}
})
}
}
// chainHasPattern reports whether the handler chain holds an entry registered
// for the given fqdn pattern at the given priority.
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
for _, h := range s.handlerChain.handlers {
if h.OrigPattern == pattern && h.Priority == priority {
return true
}
}
return false
}
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
// tracks each (handler, domain) registration independently when one handler
// serves multiple zones. Every custom zone is served by the same handler
// instance (the local resolver, whose ID is the constant "local-resolver"), so
// removing one zone must deregister exactly that zone's chain entry and leave
// the others in place. Tracking registrations by handler ID alone collapses all
// zones onto one entry, leaving removed zones in the chain to answer
// authoritatively with no records.
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
// One handler serves every custom zone, mirroring s.localResolver.
shared := &mockHandler{Id: "local-resolver"}
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Two custom zones under the same handler. The surviving zone is registered
// last, mirroring the management emission order.
server.updateMux([]handlerWrapper{
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test should be registered after the first update")
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should be registered after the first update")
// Remove one zone, keep the other.
server.updateMux([]handlerWrapper{
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should remain after removing userzone.test")
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test handler must be deregistered, not leaked in the chain")
}
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
// does not tear down the shared local resolver during reconfiguration. The
// resolver is a process-lifetime singleton reused across config updates;
// Stop() cancels its lookup context (breaking external CNAME-target
// resolution) and clears its records. updateMux must deregister its chain
// entries without stopping it. Records surviving a teardown update is the
// observable proxy: Stop() would have cleared them.
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
resolver := local.NewResolver()
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
Name: "peer.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}))
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
localResolver: resolver,
}
server.updateMux([]handlerWrapper{
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
})
// Remove the zone. The resolver must survive so its records and lookup
// context stay intact for the next registration.
server.updateMux(nil)
var response *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
response = m
return nil
},
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
require.NotNil(t, response, "local resolver should answer after teardown")
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
}
func TestExtraDomains(t *testing.T) {
tests := []struct {
name string
@@ -2049,7 +2136,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
@@ -2207,7 +2293,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
}
}
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
@@ -2283,12 +2369,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
@@ -2395,7 +2480,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
@@ -2407,7 +2491,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2444,7 +2528,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
@@ -2459,7 +2542,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2484,6 +2567,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestWarningDelayBaseFromEnv(t *testing.T) {
tests := []struct {
name string
set bool
val string
want time.Duration
}{
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envWarningDelay, tc.val)
if !tc.set {
os.Unsetenv(envWarningDelay)
}
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
})
}
}
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
@@ -2595,7 +2704,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
@@ -2613,7 +2721,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2640,7 +2748,6 @@ func TestDNSLoopPrevention(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct {

View File

@@ -443,21 +443,25 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
// A valid response means the upstream is reachable, whatever the Rcode.
u.markUpstreamOk(upstream)
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
// refused zones, transient recursion errors), not reachability
// problems: fail over for a better answer but keep the upstream healthy.
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
@@ -465,7 +469,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}

View File

@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
// those are per-question outcomes from a reachable server and must not mark
// the upstream unhealthy. Only transport failures (timeouts) do.
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.10:53")
b := netip.MustParseAddrPort("192.0.2.11:53")
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
tests := []struct {
name string
respA mockUpstreamResponse
respB mockUpstreamResponse
wantHealthy bool
}{
{
name: "both SERVFAIL are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
wantHealthy: true,
},
{
name: "both REFUSED are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
wantHealthy: true,
},
{
name: "timeout marks unhealthy",
respA: mockUpstreamResponse{err: timeoutErr},
respB: mockUpstreamResponse{err: timeoutErr},
wantHealthy: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): tc.respA,
b.String(): tc.respB,
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a, b})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, a, "primary upstream should have a health record")
if tc.wantHealthy {
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
} else {
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
}
})
}
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
return srv
}
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -33,6 +33,8 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
@@ -82,6 +84,9 @@ type ProxyServiceServer struct {
// Manager for users
usersManager users.Manager
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
idpManager idp.Manager
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
@@ -157,7 +162,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -166,6 +171,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
idpManager: idpManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
@@ -1702,22 +1708,7 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
@@ -1754,6 +1745,45 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}, nil
}
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
// user or peer ID, and peer name or user email.
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
// Resolve the principal: when the peer is linked to a user, the human is the
// principal so multiple peers owned by the same user share a single
// identity. Unlinked peers (machine agents) are their own principal keyed on
// peer.ID. displayIdentity is what upstream gateways tag spend with —
// user.Email when linked, peer.Name when not.
// If the peer isn't associated with a user, return the peer info directly.
if peer.UserID == "" {
return peer.ID, peer.Name
}
// Otherwise, if the peer is linked to a user, the user is the principal and
// if an IdP is available, we gather details on the user from it.
principalID := peer.UserID
displayIdentity := peer.Name
// Stored column first (cheap, but often empty for OIDC-provisioned users).
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
// IdP enrichment wins when available — the stored email column is a
// best-effort cache and is frequently empty for OIDC users. Enrichment
// failures must never fail the RPC; we simply keep the stored/peer identity.
if s.idpManager != nil {
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
displayIdentity = ud.Email
} else if uerr != nil {
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
}
}
return principalID, displayIdentity
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security

View File

@@ -3,14 +3,19 @@ package grpc
import (
"context"
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockReverseProxyManager struct {
@@ -137,6 +142,52 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
return user, nil, nil
}
// mockTunnelPeersManager implements only the two peers.Manager methods that
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
// panics if any unexpected method is invoked).
type mockTunnelPeersManager struct {
peers.Manager
peer *peer.Peer
peerErr error
groups []*types.Group
groupsErr error
}
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
return m.peer, m.peerErr
}
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
return m.peer, m.groups, m.groupsErr
}
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
// an IdP that knows nothing about the user.
type mockTunnelIdpManager struct {
idp.Manager
email string
hasData bool
err error
gotCalls int
gotMeta []idp.AppMetadata
}
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
m.gotCalls++
m.gotMeta = append(m.gotMeta, meta)
if m.err != nil {
return nil, m.err
}
if !m.hasData {
// This might not be a thing any of the actual IDP implementations do,
// i.e. return a nil value with no error, but it seems valuable to test
// that behavior here.
return nil, nil //nolint:nilnil
}
return &idp.UserData{ID: userID, Email: m.email}, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -354,6 +405,163 @@ func TestValidateUserGroupAccess(t *testing.T) {
}
}
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
// (IdP email -> stored User.Email -> peer.Name).
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
const (
domain = "app.example.com"
accountID = "account1"
peerID = "peer1"
peerName = "peer-display-name"
userID = "user1"
)
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
tests := []struct {
name string
peerUserID string
storedUsers map[string]*types.User
storedErr error
noIdP bool
idpEmail string
idpHasData bool
idpErr error
expectEmail string
expectUserID string
expectIdPHit bool
}{
{
name: "idp email wins over stored email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp returns empty email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "",
idpHasData: true,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp has no data",
peerUserID: userID,
storedUsers: storedUser,
idpHasData: false,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp errors",
peerUserID: userID,
storedUsers: storedUser,
idpErr: errors.New("idp unreachable"),
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when no idp manager",
peerUserID: userID,
storedUsers: storedUser,
noIdP: true,
expectEmail: "stored@example.com",
expectUserID: userID,
},
{
name: "idp email when stored email is empty",
peerUserID: userID,
storedUsers: storedUserNoEmail,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "idp email when stored user missing keeps peer.UserID as principal",
peerUserID: userID,
storedUsers: map[string]*types.User{},
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "unlinked peer uses peer name and never consults idp",
peerUserID: "",
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: peerName,
expectUserID: peerID,
expectIdPHit: false,
},
{
name: "linked peer with empty stored email and no idp falls back to peer name",
peerUserID: userID,
storedUsers: storedUserNoEmail,
noIdP: true,
expectEmail: peerName,
expectUserID: userID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := &service.Service{Domain: domain, AccountID: accountID}
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
},
peersManager: &mockTunnelPeersManager{
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
},
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
}
var idpMock *mockTunnelIdpManager
if !tt.noIdP {
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
server.idpManager = idpMock
}
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
Domain: domain,
TunnelIp: "100.64.0.1",
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.True(t, resp.GetValid(), "expected access granted")
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
assert.Equal(t, tt.expectUserID, resp.GetUserId())
if idpMock != nil {
if tt.expectIdPHit {
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
require.Len(t, idpMock.gotMeta, 1)
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
} else {
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
}
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)

View File

@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err

View File

@@ -217,6 +217,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager,
nil,
nil,
nil,
)
proxyService.SetServiceManager(&testServiceManager{store: testStore})

View File

@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {

View File

@@ -982,8 +982,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var peer *nbpeer.Peer
var updated, versionChanged, ipv6CapabilityChanged bool
var err error
var postureChecks []*posture.Checks
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -1011,13 +1009,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return status.NewPeerLoginExpiredError()
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
@@ -1025,11 +1018,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err
}
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
}
return nil
})
@@ -1037,6 +1025,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
@@ -1047,9 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1124,7 +1117,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
var peer *nbpeer.Peer
var shouldStorePeer bool
var shouldStorePeer, shouldUpdatePeers bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1151,14 +1144,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
shouldUpdatePeers = true
}
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
@@ -1180,7 +1169,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
peer.UpdateMetaIfNew(ctx, login.Meta)
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, false, err
}
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, false, err
}
@@ -1190,7 +1187,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
if isStatusChanged || shouldStorePeer {
if shouldUpdatePeers {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1286,12 +1283,22 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, nil, false, nil
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
if err != nil {
return nil, nil, false, err
}
@@ -1299,32 +1306,16 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, postureChecks, enableSSH, nil
}
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
for _, peerID := range peerGroupIDs {
groupIDsMap[peerID] = struct{}{}
}
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return false, err
}
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
for _, g := range peerGroups {
peerGroupIDs[g.ID] = struct{}{}
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
}
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
if len(policies) == 0 {
return nil, nil
}
@@ -1336,11 +1327,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
continue
}
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
if err != nil {
return nil, err
}
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
}
@@ -1353,29 +1340,19 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
}
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil {
return nil, err
}
for _, sourceGroup := range rule.Sources {
group, ok := sourceGroups[sourceGroup]
if !ok {
return nil, fmt.Errorf("failed to check peer in policy source group")
}
if slices.Contains(group.Peers, peerID) {
return policy.SourcePostureChecks, nil
if slices.Contains(peerGroupIDs, sourceGroup) {
return policy.SourcePostureChecks
}
}
}
return nil, nil
return nil
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO

View File

@@ -1,12 +1,16 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -162,49 +166,7 @@ type PeerSystemMeta struct { //nolint:revive
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
})
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
})
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
sort.Slice(p.Files, func(i, j int) bool {
return p.Files[i].Path < p.Files[j].Path
})
sort.Slice(other.Files, func(i, j int) bool {
return other.Files[i].Path < other.Files[j].Path
})
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.KernelVersion == other.KernelVersion &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.OSVersion == other.OSVersion &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion &&
p.SystemSerialNumber == other.SystemSerialNumber &&
p.SystemProductName == other.SystemProductName &&
p.SystemManufacturer == other.SystemManufacturer &&
p.Environment.Cloud == other.Environment.Cloud &&
p.Environment.Platform == other.Environment.Platform &&
p.Flags.isEqual(other.Flags) &&
capabilitiesEqual(p.Capabilities, other.Capabilities)
return len(metaDiff(p, other)) == 0
}
func (p PeerSystemMeta) isEmpty() bool {
@@ -296,7 +258,7 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
if meta.isEmpty() {
return updated, versionChanged
}
@@ -308,14 +270,121 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged boo
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return updated, versionChanged
oldVersion := p.Meta.WtVersion
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
p.Meta = meta
updated = true
}
p.Meta = meta
updated = true
versionInfo := ""
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return updated, versionChanged
}
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
// source of truth for meta comparison: isEqual reports equality as an empty
// diff, so the log line can never disagree with the change decision. Slices are
// cloned before sorting, so callers' meta is not mutated.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
var diff []string
add := func(field string, oldVal, newVal any) {
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return diff
}
// sameMultiset reports whether two slices contain the same elements with the
// same multiplicity, ignoring order. The element type is the comparison key, so
// every field participates in equality.
func sameMultiset[T comparable](a, b []T) bool {
if len(a) != len(b) {
return false
}
counts := make(map[T]int, len(a))
for _, v := range a {
counts[v]++
}
for _, v := range b {
counts[v]--
if counts[v] == 0 {
delete(counts, v)
}
}
return len(counts) == 0
}
// GetLastLogin returns the last login time of the peer.
func (p *Peer) GetLastLogin() time.Time {
if p.LastLogin != nil {

View File

@@ -0,0 +1,113 @@
package peer
import (
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
// map 1:1 to a single diff entry. Today the only such field is Environment, which
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
// entry beyond its single struct field. If you teach metaDiff to explode another
// field into N entries, bump this by N-1; if you collapse a field, lower it.
const metaDiffExtraEntries = 1
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
// values and diffs it against the zero value, then asserts metaDiff emits exactly
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
//
// The expected count is derived from the struct via reflection, so adding a field
// to PeerSystemMeta raises the expectation automatically — but the actual diff
// only grows if metaDiff was taught to compare the new field. A mismatch means
// someone changed the struct without updating metaDiff (or this test's
// extra-entry accounting), which is exactly what we want to catch.
func TestMetaDiff_CoversAllFields(t *testing.T) {
var full PeerSystemMeta
exported := populateAll(t, reflect.ValueOf(&full).Elem())
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
diff := metaDiff(PeerSystemMeta{}, full)
require.Len(t, diff, exported+metaDiffExtraEntries,
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
"diff was: %v", diff)
require.False(t, full.isEqual(PeerSystemMeta{}),
"isEqual must report a fully-populated meta as different from the zero value")
}
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
// compare would not change the diff count. This flips each Flags field on its own
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
// fails here.
func TestFlags_isEqualChecksEveryField(t *testing.T) {
typ := reflect.TypeOf(Flags{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
require.Equal(t, reflect.Bool, f.Type.Kind(),
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
var a, b Flags
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
}
}
// populateAll sets every exported field of the struct to a deterministic non-zero
// value, recursing into nested structs and the element type of struct slices so
// that each leaf differs from zero. It returns the number of exported fields on
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
// settable exported fields and is comparable with ==).
func populateAll(t *testing.T, v reflect.Value) int {
t.Helper()
typ := v.Type()
exported := 0
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath != "" { // unexported
continue
}
exported++
setNonZero(t, v.Field(i))
}
return exported
}
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
// recursing into nested structs and populating one element of slice fields.
func setNonZero(t *testing.T, field reflect.Value) {
t.Helper()
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
return
}
switch field.Kind() {
case reflect.String:
field.SetString("non-zero")
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Float32, reflect.Float64:
field.SetFloat(7)
case reflect.Struct:
populateAll(t, field)
case reflect.Slice:
s := reflect.MakeSlice(field.Type(), 1, 1)
setNonZero(t, s.Index(0))
field.Set(s)
default:
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
}
}

View File

@@ -125,6 +125,7 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
oidcConfig,
nil,
usersManager,
nil,
realProxyManager,
nil,
)

View File

@@ -140,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
oidcConfig,
nil,
usersManager,
nil,
proxyManager,
nil,
)

View File

@@ -21,7 +21,8 @@ AWK_FIRST_FIELD='{print $1}'
fetch_all_tags() {
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \
sort -u -V
return 0

View File

@@ -32,7 +32,8 @@ fetch_current_ports_version() {
fetch_all_tags() {
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \
sort -u -V
return 0

View File

@@ -26,6 +26,10 @@ type Peer struct {
// a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
// the same ServerStream, and a peer can be the target of many senders at
// once.
sendMu sync.Mutex
// registration time
RegisteredAt time.Time
@@ -33,6 +37,13 @@ type Peer struct {
Cancel context.CancelFunc
}
// Send writes a message to the peer's stream, serializing concurrent senders.
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.Stream.Send(msg)
}
// NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
return &Peer{

View File

@@ -0,0 +1,67 @@
package server
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/peer"
)
// concurrencyCheckStream records the maximum number of Send calls in flight at
// once. gRPC forbids concurrent SendMsg on the same ServerStream, so a correct
// server must never have more than one in flight per peer.
type concurrencyCheckStream struct {
proto.SignalExchange_ConnectStreamServer
ctx context.Context
inflight atomic.Int32
maxSeen atomic.Int32
}
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
n := s.inflight.Add(1)
for {
old := s.maxSeen.Load()
if n <= old || s.maxSeen.CompareAndSwap(old, n) {
break
}
}
// Widen the window so overlapping callers are reliably observed.
time.Sleep(time.Millisecond)
s.inflight.Add(-1)
return nil
}
func (s *concurrencyCheckStream) Context() context.Context { return s.ctx }
// TestForwardMessageToPeerSerializesSend verifies that concurrent forwards to the
// same peer never call Stream.Send concurrently, which would violate the gRPC
// ServerStream contract.
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
s, err := NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
const peerID = "peerX"
stream := &concurrencyCheckStream{ctx: context.Background()}
_, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
require.NoError(t, s.registry.Register(peer.NewPeer(peerID, stream, cancel)))
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s.forwardMessageToPeer(context.Background(), &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID})
}()
}
wg.Wait()
require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
}

View File

@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
sendResultChan := make(chan error, 1)
go func() {
select {
case sendResultChan <- dstPeer.Stream.Send(msg):
case sendResultChan <- dstPeer.Send(msg):
return
case <-dstPeer.Stream.Context().Done():
return