mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-19 22:39:59 +00:00
Compare commits
11 Commits
fix/ipv6-a
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58c79f5878 | ||
|
|
15a0504fb1 | ||
|
|
883a1a8961 | ||
|
|
54192a94b7 | ||
|
|
8511687270 | ||
|
|
35b465fa4a | ||
|
|
fb87f751a5 | ||
|
|
679c7182a4 | ||
|
|
8c031ea6f0 | ||
|
|
60a9544656 | ||
|
|
d3710d4bb2 |
@@ -247,7 +247,7 @@ dockers_v2:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
@@ -295,7 +295,7 @@ dockers_v2:
|
||||
- netbirdio/relay
|
||||
- ghcr.io/netbirdio/relay
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: relay/Dockerfile
|
||||
platforms:
|
||||
@@ -317,7 +317,7 @@ dockers_v2:
|
||||
- netbirdio/signal
|
||||
- ghcr.io/netbirdio/signal
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: signal/Dockerfile
|
||||
platforms:
|
||||
@@ -339,7 +339,7 @@ dockers_v2:
|
||||
- netbirdio/management
|
||||
- ghcr.io/netbirdio/management
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: management/Dockerfile
|
||||
platforms:
|
||||
@@ -361,7 +361,7 @@ dockers_v2:
|
||||
- netbirdio/upload
|
||||
- ghcr.io/netbirdio/upload
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: upload-server/Dockerfile
|
||||
platforms:
|
||||
@@ -383,7 +383,7 @@ dockers_v2:
|
||||
- netbirdio/netbird-server
|
||||
- ghcr.io/netbirdio/netbird-server
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: combined/Dockerfile
|
||||
platforms:
|
||||
@@ -405,7 +405,7 @@ dockers_v2:
|
||||
- netbirdio/reverse-proxy
|
||||
- ghcr.io/netbirdio/reverse-proxy
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: proxy/Dockerfile
|
||||
platforms:
|
||||
|
||||
@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
||||
return "", fmt.Errorf("switch profile failed: %w", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
|
||||
@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
profileName := args[0]
|
||||
|
||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add profile request: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
id := profilemanager.ID(resp.Id)
|
||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
|
||||
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
|
||||
// and returns the new profile's ID. It is the single entry point for profile
|
||||
// creation, shared by `netbird profile add` and the `netbird up --profile
|
||||
// <name>` auto-create path.
|
||||
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
|
||||
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add profile failed: %w", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
// Resolve the active profile's display name via the daemon, which runs
|
||||
// as root and can read the per-user profile files. The local profile
|
||||
// manager only knows the active profile ID, not its display name.
|
||||
profName := getActiveProfileName(ctx)
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getActiveProfileName asks the daemon for the active profile's display
|
||||
// name. The daemon runs as root and can read the per-user profile files to
|
||||
// resolve the ID to its human-readable name. Returns an empty string on any
|
||||
// error so status output degrades gracefully.
|
||||
func getActiveProfileName(ctx context.Context) string {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return resp.GetProfileName()
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "idle", "connecting", "connected":
|
||||
|
||||
@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
profileSwitched = true
|
||||
}
|
||||
|
||||
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
||||
}
|
||||
|
||||
// switchOrCreateProfile switches the active profile to the one identified by
|
||||
// handle, creating it first when it does not exist yet. This restores the
|
||||
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
|
||||
// missing profile instead of failing.
|
||||
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
|
||||
resolvedID, err := switchProfile(ctx, handle, username)
|
||||
if err != nil {
|
||||
st, ok := gstatus.FromError(err)
|
||||
if !ok || st.Code() != codes.NotFound {
|
||||
return err
|
||||
}
|
||||
// Don't fail immediately on a create error: a concurrent run may
|
||||
// have created the profile between the NotFound above and this
|
||||
// call, in which case the retried switch still succeeds. Only
|
||||
// surface the create error if the switch also fails.
|
||||
_, createErr := createProfile(ctx, handle, username)
|
||||
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
|
||||
if createErr != nil {
|
||||
return fmt.Errorf("create profile: %w", createErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProfile dials the daemon and creates a new profile with the given
|
||||
// display name, returning its generated ID. Use addProfileOnDaemon directly
|
||||
// when a daemon client is already available to reuse the connection.
|
||||
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
|
||||
}
|
||||
|
||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
||||
// override the default profile filepath if provided
|
||||
if configPath != "" {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -38,11 +39,15 @@ const (
|
||||
// defaultWarningDelayBase is the starting grace window before a
|
||||
// "Nameserver group unreachable" event fires for a group that's
|
||||
// never been healthy and only has overlay upstreams with no
|
||||
// Connected peer. Per-server and overridable; see warningDelayFor.
|
||||
defaultWarningDelayBase = 30 * time.Second
|
||||
// Connected peer. Per-server and overridable via envWarningDelay;
|
||||
// see warningDelay.
|
||||
defaultWarningDelayBase = 60 * time.Second
|
||||
// warningDelayBonusCap caps the route-count bonus added to the
|
||||
// base grace window. See warningDelayFor.
|
||||
// base grace window. See warningDelay.
|
||||
warningDelayBonusCap = 30 * time.Second
|
||||
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
|
||||
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
|
||||
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
|
||||
)
|
||||
|
||||
// errNoUsableNameservers signals that a merged-domain group has no usable
|
||||
@@ -135,7 +140,7 @@ type DefaultServer struct {
|
||||
disableSys bool
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
dnsMuxHandlers []handlerWrapper
|
||||
localResolver *local.Resolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
@@ -199,8 +204,6 @@ type handlerWrapper struct {
|
||||
priority int
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||
|
||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||
type DefaultServerConfig struct {
|
||||
WgInterface WGIface
|
||||
@@ -289,7 +292,6 @@ func newDefaultServer(
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: local.NewResolver(),
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
@@ -298,7 +300,7 @@ func newDefaultServer(
|
||||
hostManager: &noopHostConfigurator{},
|
||||
mgmtCacheResolver: mgmtCacheResolver,
|
||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
warningDelayBase: warningDelayBaseFromEnv(),
|
||||
healthRefresh: make(chan struct{}, 1),
|
||||
}
|
||||
// Wire the local resolver against the peer status recorder so it can
|
||||
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||
type routeSettable interface {
|
||||
setSelectedRoutes(func() route.HAMap)
|
||||
}
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
if h, ok := entry.handler.(routeSettable); ok {
|
||||
h.setSelectedRoutes(selected)
|
||||
}
|
||||
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
for _, existing := range s.dnsMuxHandlers {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
existing.handler.Stop()
|
||||
// The local resolver is a persistent singleton shared by every custom
|
||||
// zone and reused across config updates. Its chain registrations are
|
||||
// per-config and must be deregistered, but Stop() cancels its lookup
|
||||
// context (breaking external CNAME-target resolution) and clears its
|
||||
// records, so it must not be torn down here.
|
||||
if existing.handler != s.localResolver {
|
||||
existing.handler.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.ID()] = update
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
s.dnsMuxHandlers = muxUpdates
|
||||
}
|
||||
|
||||
// updateNSGroupStates records the new group set and pokes the refresher.
|
||||
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
|
||||
return false
|
||||
}
|
||||
|
||||
// warningDelayBaseFromEnv returns the base grace window, honoring
|
||||
// envWarningDelay when it holds a valid positive Go duration. Invalid or
|
||||
// non-positive values fall back to defaultWarningDelayBase.
|
||||
func warningDelayBaseFromEnv() time.Duration {
|
||||
val := os.Getenv(envWarningDelay)
|
||||
if val == "" {
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
if d <= 0 {
|
||||
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// warningDelay returns the grace window for the given selected-route
|
||||
// count. Scales gently: +1s per 100 routes, capped by
|
||||
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
||||
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
|
||||
// in more than one handler.
|
||||
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
merged := make(map[netip.AddrPort]UpstreamHealth)
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
reporter, ok := entry.handler.(upstreamHealthReporter)
|
||||
if !ok {
|
||||
continue
|
||||
|
||||
@@ -104,19 +104,6 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@@ -132,22 +119,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initUpstreamMap []handlerWrapper
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -169,20 +154,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
@@ -191,10 +173,10 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
@@ -215,15 +197,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
@@ -232,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
@@ -240,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -262,7 +242,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -284,7 +264,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -301,42 +281,41 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
@@ -393,7 +372,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
@@ -405,14 +384,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,8 +497,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
@@ -1029,15 +1014,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
func (m *mockService) DeregisterMux(string) {}
|
||||
|
||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
baseMatchHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
baseMatchHandlers := []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"upstream-group2": {
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
@@ -1046,15 +1031,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseRootHandlers := registeredHandlerMap{
|
||||
"upstream-root1": {
|
||||
baseRootHandlers := []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
"upstream-root2": {
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
@@ -1063,22 +1048,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseMixedHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
baseMixedHandlers := []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"upstream-group2": {
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
},
|
||||
"upstream-other": {
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
@@ -1089,7 +1074,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialHandlers registeredHandlerMap
|
||||
initialHandlers []handlerWrapper
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[HandlerID]domain
|
||||
description string
|
||||
@@ -1373,32 +1358,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
dnsMuxMap: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
dnsMuxHandlers: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
server.updateMux(tt.updates)
|
||||
|
||||
// Verify the results
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
|
||||
"Number of handlers after update doesn't match expected")
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
var found *handlerWrapper
|
||||
for i := range server.dnsMuxHandlers {
|
||||
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
|
||||
found = &server.dnsMuxHandlers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, found, "Expected handler %s not found", id)
|
||||
if found != nil {
|
||||
assert.Equal(t, expectedDomain, found.domain,
|
||||
"Domain mismatch for handler %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for HandlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||
for _, entry := range server.dnsMuxHandlers {
|
||||
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
@@ -1413,7 +1404,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
// Verify handler exists in mux
|
||||
foundInMux := false
|
||||
for _, muxEntry := range server.dnsMuxMap {
|
||||
for _, muxEntry := range server.dnsMuxHandlers {
|
||||
if chainEntry.Handler == muxEntry.handler &&
|
||||
chainEntry.Priority == muxEntry.priority &&
|
||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||
@@ -1422,12 +1413,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInMux,
|
||||
"Handler in chain not found in dnsMuxMap")
|
||||
"Handler in chain not found in dnsMuxHandlers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// chainHasPattern reports whether the handler chain holds an entry registered
|
||||
// for the given fqdn pattern at the given priority.
|
||||
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
|
||||
for _, h := range s.handlerChain.handlers {
|
||||
if h.OrigPattern == pattern && h.Priority == priority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
|
||||
// tracks each (handler, domain) registration independently when one handler
|
||||
// serves multiple zones. Every custom zone is served by the same handler
|
||||
// instance (the local resolver, whose ID is the constant "local-resolver"), so
|
||||
// removing one zone must deregister exactly that zone's chain entry and leave
|
||||
// the others in place. Tracking registrations by handler ID alone collapses all
|
||||
// zones onto one entry, leaving removed zones in the chain to answer
|
||||
// authoritatively with no records.
|
||||
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
|
||||
// One handler serves every custom zone, mirroring s.localResolver.
|
||||
shared := &mockHandler{Id: "local-resolver"}
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Two custom zones under the same handler. The surviving zone is registered
|
||||
// last, mirroring the management emission order.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test should be registered after the first update")
|
||||
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should be registered after the first update")
|
||||
|
||||
// Remove one zone, keep the other.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should remain after removing userzone.test")
|
||||
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test handler must be deregistered, not leaked in the chain")
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
|
||||
// does not tear down the shared local resolver during reconfiguration. The
|
||||
// resolver is a process-lifetime singleton reused across config updates;
|
||||
// Stop() cancels its lookup context (breaking external CNAME-target
|
||||
// resolution) and clears its records. updateMux must deregister its chain
|
||||
// entries without stopping it. Records surviving a teardown update is the
|
||||
// observable proxy: Stop() would have cleared them.
|
||||
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
|
||||
resolver := local.NewResolver()
|
||||
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
|
||||
Name: "peer.netbird.cloud.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.0.0.1",
|
||||
}))
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
localResolver: resolver,
|
||||
}
|
||||
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
// Remove the zone. The resolver must survive so its records and lookup
|
||||
// context stay intact for the next registration.
|
||||
server.updateMux(nil)
|
||||
|
||||
var response *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
response = m
|
||||
return nil
|
||||
},
|
||||
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
|
||||
|
||||
require.NotNil(t, response, "local resolver should answer after teardown")
|
||||
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
|
||||
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
|
||||
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
|
||||
}
|
||||
|
||||
func TestExtraDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -2049,7 +2136,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
groups := []*nbdns.NameServerGroup{
|
||||
@@ -2207,7 +2293,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
|
||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||
// without spinning up real handlers.
|
||||
type healthStubHandler struct {
|
||||
@@ -2283,12 +2369,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||
activeRoutes: func() route.HAMap { return fx.active },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
}
|
||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
|
||||
|
||||
fx.server.mux.Lock()
|
||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||
@@ -2395,7 +2480,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: 50 * time.Millisecond,
|
||||
@@ -2407,7 +2491,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2444,7 +2528,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
service: NewServiceViaMemory(wgIface),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
extraDomains: map[domain.Domain]int{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
@@ -2459,7 +2542,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2484,6 +2567,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||
// comes up and a query succeeds before the grace window elapses. No
|
||||
// warning should ever have fired, and no recovery either.
|
||||
func TestWarningDelayBaseFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
set bool
|
||||
val string
|
||||
want time.Duration
|
||||
}{
|
||||
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
|
||||
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
|
||||
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
|
||||
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
|
||||
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
|
||||
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envWarningDelay, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(envWarningDelay)
|
||||
}
|
||||
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||
@@ -2595,7 +2704,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: time.Hour,
|
||||
@@ -2613,7 +2721,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
},
|
||||
}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2640,7 +2748,6 @@ func TestDNSLoopPrevention(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
||||
@@ -443,21 +443,25 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
// A valid response means the upstream is reachable, whatever the Rcode.
|
||||
u.markUpstreamOk(upstream)
|
||||
|
||||
proto := ""
|
||||
if upstreamProto != nil {
|
||||
proto = upstreamProto.protocol
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
|
||||
// refused zones, transient recursion errors), not reachability
|
||||
// problems: fail over for a better answer but keep the upstream healthy.
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
reason := dns.RcodeToString[rm.Rcode]
|
||||
u.markUpstreamFail(upstream, reason)
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
@@ -465,7 +469,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
stripOPT(rm)
|
||||
}
|
||||
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
|
||||
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
|
||||
// those are per-question outcomes from a reachable server and must not mark
|
||||
// the upstream unhealthy. Only transport failures (timeouts) do.
|
||||
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
|
||||
a := netip.MustParseAddrPort("192.0.2.10:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.11:53")
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respA mockUpstreamResponse
|
||||
respB mockUpstreamResponse
|
||||
wantHealthy bool
|
||||
}{
|
||||
{
|
||||
name: "both SERVFAIL are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "both REFUSED are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "timeout marks unhealthy",
|
||||
respA: mockUpstreamResponse{err: timeoutErr},
|
||||
respB: mockUpstreamResponse{err: timeoutErr},
|
||||
wantHealthy: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
a.String(): tc.respA,
|
||||
b.String(): tc.respB,
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{a, b})
|
||||
|
||||
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
|
||||
health := resolver.UpstreamHealth()
|
||||
require.Contains(t, health, a, "primary upstream should have a health record")
|
||||
if tc.wantHealthy {
|
||||
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
|
||||
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
|
||||
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
|
||||
} else {
|
||||
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
|
||||
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
|
||||
@@ -33,6 +33,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
@@ -82,6 +84,9 @@ type ProxyServiceServer struct {
|
||||
// Manager for users
|
||||
usersManager users.Manager
|
||||
|
||||
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
|
||||
idpManager idp.Manager
|
||||
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
@@ -157,7 +162,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -166,6 +171,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
pkceVerifierStore: pkceStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
idpManager: idpManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
@@ -1702,22 +1708,7 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the human
|
||||
// is the principal so multiple peers owned by the same user share a
|
||||
// single identity. Unlinked peers (machine agents) are their own
|
||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// tag spend with — user.Email when linked, peer.Name when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||
@@ -1754,6 +1745,45 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
|
||||
// user or peer ID, and peer name or user email.
|
||||
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
|
||||
// Resolve the principal: when the peer is linked to a user, the human is the
|
||||
// principal so multiple peers owned by the same user share a single
|
||||
// identity. Unlinked peers (machine agents) are their own principal keyed on
|
||||
// peer.ID. displayIdentity is what upstream gateways tag spend with —
|
||||
// user.Email when linked, peer.Name when not.
|
||||
|
||||
// If the peer isn't associated with a user, return the peer info directly.
|
||||
if peer.UserID == "" {
|
||||
return peer.ID, peer.Name
|
||||
}
|
||||
|
||||
// Otherwise, if the peer is linked to a user, the user is the principal and
|
||||
// if an IdP is available, we gather details on the user from it.
|
||||
principalID := peer.UserID
|
||||
displayIdentity := peer.Name
|
||||
// Stored column first (cheap, but often empty for OIDC-provisioned users).
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
// IdP enrichment wins when available — the stored email column is a
|
||||
// best-effort cache and is frequently empty for OIDC users. Enrichment
|
||||
// failures must never fail the RPC; we simply keep the stored/peer identity.
|
||||
if s.idpManager != nil {
|
||||
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
|
||||
displayIdentity = ud.Email
|
||||
} else if uerr != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
|
||||
}
|
||||
}
|
||||
|
||||
return principalID, displayIdentity
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||
// groups. Private services authorise against AccessGroups (empty list fails
|
||||
// closed — Validate() rejects that at save time but the RPC is the security
|
||||
|
||||
@@ -3,14 +3,19 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type mockReverseProxyManager struct {
|
||||
@@ -137,6 +142,52 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
// mockTunnelPeersManager implements only the two peers.Manager methods that
|
||||
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
|
||||
// panics if any unexpected method is invoked).
|
||||
type mockTunnelPeersManager struct {
|
||||
peers.Manager
|
||||
peer *peer.Peer
|
||||
peerErr error
|
||||
groups []*types.Group
|
||||
groupsErr error
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
|
||||
return m.peer, m.peerErr
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
|
||||
return m.peer, m.groups, m.groupsErr
|
||||
}
|
||||
|
||||
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
|
||||
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
|
||||
// an IdP that knows nothing about the user.
|
||||
type mockTunnelIdpManager struct {
|
||||
idp.Manager
|
||||
email string
|
||||
hasData bool
|
||||
err error
|
||||
gotCalls int
|
||||
gotMeta []idp.AppMetadata
|
||||
}
|
||||
|
||||
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
|
||||
m.gotCalls++
|
||||
m.gotMeta = append(m.gotMeta, meta)
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
if !m.hasData {
|
||||
// This might not be a thing any of the actual IDP implementations do,
|
||||
// i.e. return a nil value with no error, but it seems valuable to test
|
||||
// that behavior here.
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return &idp.UserData{ID: userID, Email: m.email}, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -354,6 +405,163 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
|
||||
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
|
||||
// (IdP email -> stored User.Email -> peer.Name).
|
||||
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
|
||||
const (
|
||||
domain = "app.example.com"
|
||||
accountID = "account1"
|
||||
peerID = "peer1"
|
||||
peerName = "peer-display-name"
|
||||
userID = "user1"
|
||||
)
|
||||
|
||||
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
|
||||
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
peerUserID string
|
||||
storedUsers map[string]*types.User
|
||||
storedErr error
|
||||
noIdP bool
|
||||
idpEmail string
|
||||
idpHasData bool
|
||||
idpErr error
|
||||
expectEmail string
|
||||
expectUserID string
|
||||
expectIdPHit bool
|
||||
}{
|
||||
{
|
||||
name: "idp email wins over stored email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp returns empty email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "",
|
||||
idpHasData: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp has no data",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpHasData: false,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp errors",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpErr: errors.New("idp unreachable"),
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when no idp manager",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
noIdP: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored email is empty",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored user missing keeps peer.UserID as principal",
|
||||
peerUserID: userID,
|
||||
storedUsers: map[string]*types.User{},
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "unlinked peer uses peer name and never consults idp",
|
||||
peerUserID: "",
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: peerID,
|
||||
expectIdPHit: false,
|
||||
},
|
||||
{
|
||||
name: "linked peer with empty stored email and no idp falls back to peer name",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
noIdP: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: userID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := &service.Service{Domain: domain, AccountID: accountID}
|
||||
server := &ProxyServiceServer{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
|
||||
},
|
||||
peersManager: &mockTunnelPeersManager{
|
||||
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
|
||||
},
|
||||
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
|
||||
}
|
||||
|
||||
var idpMock *mockTunnelIdpManager
|
||||
if !tt.noIdP {
|
||||
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
|
||||
server.idpManager = idpMock
|
||||
}
|
||||
|
||||
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
|
||||
Domain: domain,
|
||||
TunnelIp: "100.64.0.1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.True(t, resp.GetValid(), "expected access granted")
|
||||
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
|
||||
assert.Equal(t, tt.expectUserID, resp.GetUserId())
|
||||
|
||||
if idpMock != nil {
|
||||
if tt.expectIdPHit {
|
||||
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
|
||||
require.Len(t, idpMock.gotMeta, 1)
|
||||
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
|
||||
} else {
|
||||
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -217,6 +217,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
usersManager,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
|
||||
@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -982,8 +982,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
var peer *nbpeer.Peer
|
||||
var updated, versionChanged, ipv6CapabilityChanged bool
|
||||
var err error
|
||||
var postureChecks []*posture.Checks
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -1011,13 +1009,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
@@ -1025,11 +1018,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -1037,6 +1025,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
@@ -1047,9 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1124,7 +1117,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var shouldStorePeer bool
|
||||
var shouldStorePeer, shouldUpdatePeers bool
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1151,14 +1144,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
shouldUpdatePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
@@ -1180,7 +1169,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||
peer.UpdateMetaIfNew(ctx, login.Meta)
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
@@ -1190,7 +1187,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if isStatusChanged || shouldStorePeer {
|
||||
if shouldUpdatePeers {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1286,12 +1283,22 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, nil, false, nil
|
||||
}
|
||||
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
|
||||
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
@@ -1299,32 +1306,16 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
|
||||
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
|
||||
for _, peerID := range peerGroupIDs {
|
||||
groupIDsMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
|
||||
for _, g := range peerGroups {
|
||||
peerGroupIDs[g.ID] = struct{}{}
|
||||
}
|
||||
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1336,11 +1327,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
|
||||
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
||||
}
|
||||
|
||||
@@ -1353,29 +1340,19 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
}
|
||||
|
||||
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
|
||||
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
|
||||
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group, ok := sourceGroups[sourceGroup]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to check peer in policy source group")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return policy.SourcePostureChecks, nil
|
||||
if slices.Contains(peerGroupIDs, sourceGroup) {
|
||||
return policy.SourcePostureChecks
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
@@ -162,49 +166,7 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
|
||||
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
|
||||
})
|
||||
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
|
||||
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
|
||||
})
|
||||
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
|
||||
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
|
||||
})
|
||||
if !equalNetworkAddresses {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(p.Files, func(i, j int) bool {
|
||||
return p.Files[i].Path < p.Files[j].Path
|
||||
})
|
||||
sort.Slice(other.Files, func(i, j int) bool {
|
||||
return other.Files[i].Path < other.Files[j].Path
|
||||
})
|
||||
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
|
||||
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
|
||||
})
|
||||
if !equalFiles {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.Hostname == other.Hostname &&
|
||||
p.GoOS == other.GoOS &&
|
||||
p.Kernel == other.Kernel &&
|
||||
p.KernelVersion == other.KernelVersion &&
|
||||
p.Core == other.Core &&
|
||||
p.Platform == other.Platform &&
|
||||
p.OS == other.OS &&
|
||||
p.OSVersion == other.OSVersion &&
|
||||
p.WtVersion == other.WtVersion &&
|
||||
p.UIVersion == other.UIVersion &&
|
||||
p.SystemSerialNumber == other.SystemSerialNumber &&
|
||||
p.SystemProductName == other.SystemProductName &&
|
||||
p.SystemManufacturer == other.SystemManufacturer &&
|
||||
p.Environment.Cloud == other.Environment.Cloud &&
|
||||
p.Environment.Platform == other.Environment.Platform &&
|
||||
p.Flags.isEqual(other.Flags) &&
|
||||
capabilitiesEqual(p.Capabilities, other.Capabilities)
|
||||
return len(metaDiff(p, other)) == 0
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEmpty() bool {
|
||||
@@ -296,7 +258,7 @@ func (p *Peer) Copy() *Peer {
|
||||
|
||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||
// returns true if meta was updated, false otherwise
|
||||
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
if meta.isEmpty() {
|
||||
return updated, versionChanged
|
||||
}
|
||||
@@ -308,14 +270,121 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged boo
|
||||
meta.UIVersion = p.Meta.UIVersion
|
||||
}
|
||||
|
||||
if p.Meta.isEqual(meta) {
|
||||
return updated, versionChanged
|
||||
oldVersion := p.Meta.WtVersion
|
||||
|
||||
diff := metaDiff(p.Meta, meta)
|
||||
if len(diff) != 0 {
|
||||
p.Meta = meta
|
||||
updated = true
|
||||
}
|
||||
p.Meta = meta
|
||||
updated = true
|
||||
|
||||
versionInfo := ""
|
||||
if versionChanged {
|
||||
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||
}
|
||||
|
||||
if len(diff) > 0 || versionChanged {
|
||||
log.WithContext(ctx).
|
||||
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
|
||||
}
|
||||
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
// metaDiff returns a human-readable list of the fields that differ between the
|
||||
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
|
||||
// source of truth for meta comparison: isEqual reports equality as an empty
|
||||
// diff, so the log line can never disagree with the change decision. Slices are
|
||||
// cloned before sorting, so callers' meta is not mutated.
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
var diff []string
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||
}
|
||||
if oldMeta.GoOS != newMeta.GoOS {
|
||||
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||
}
|
||||
if oldMeta.Kernel != newMeta.Kernel {
|
||||
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||
}
|
||||
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||
}
|
||||
if oldMeta.Core != newMeta.Core {
|
||||
add("core", oldMeta.Core, newMeta.Core)
|
||||
}
|
||||
if oldMeta.Platform != newMeta.Platform {
|
||||
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||
}
|
||||
if oldMeta.OS != newMeta.OS {
|
||||
add("os", oldMeta.OS, newMeta.OS)
|
||||
}
|
||||
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||
}
|
||||
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||
}
|
||||
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||
}
|
||||
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||
}
|
||||
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||
}
|
||||
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||
}
|
||||
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||
}
|
||||
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||
}
|
||||
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||
}
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
// sameMultiset reports whether two slices contain the same elements with the
|
||||
// same multiplicity, ignoring order. The element type is the comparison key, so
|
||||
// every field participates in equality.
|
||||
func sameMultiset[T comparable](a, b []T) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
counts := make(map[T]int, len(a))
|
||||
for _, v := range a {
|
||||
counts[v]++
|
||||
}
|
||||
for _, v := range b {
|
||||
counts[v]--
|
||||
if counts[v] == 0 {
|
||||
delete(counts, v)
|
||||
}
|
||||
}
|
||||
return len(counts) == 0
|
||||
}
|
||||
|
||||
// GetLastLogin returns the last login time of the peer.
|
||||
func (p *Peer) GetLastLogin() time.Time {
|
||||
if p.LastLogin != nil {
|
||||
|
||||
113
management/server/peer/peer_metadiff_test.go
Normal file
113
management/server/peer/peer_metadiff_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
|
||||
// map 1:1 to a single diff entry. Today the only such field is Environment, which
|
||||
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
|
||||
// entry beyond its single struct field. If you teach metaDiff to explode another
|
||||
// field into N entries, bump this by N-1; if you collapse a field, lower it.
|
||||
const metaDiffExtraEntries = 1
|
||||
|
||||
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
|
||||
// values and diffs it against the zero value, then asserts metaDiff emits exactly
|
||||
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
|
||||
//
|
||||
// The expected count is derived from the struct via reflection, so adding a field
|
||||
// to PeerSystemMeta raises the expectation automatically — but the actual diff
|
||||
// only grows if metaDiff was taught to compare the new field. A mismatch means
|
||||
// someone changed the struct without updating metaDiff (or this test's
|
||||
// extra-entry accounting), which is exactly what we want to catch.
|
||||
func TestMetaDiff_CoversAllFields(t *testing.T) {
|
||||
var full PeerSystemMeta
|
||||
exported := populateAll(t, reflect.ValueOf(&full).Elem())
|
||||
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
|
||||
|
||||
diff := metaDiff(PeerSystemMeta{}, full)
|
||||
|
||||
require.Len(t, diff, exported+metaDiffExtraEntries,
|
||||
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
|
||||
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
|
||||
"diff was: %v", diff)
|
||||
|
||||
require.False(t, full.isEqual(PeerSystemMeta{}),
|
||||
"isEqual must report a fully-populated meta as different from the zero value")
|
||||
}
|
||||
|
||||
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
|
||||
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
|
||||
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
|
||||
// compare would not change the diff count. This flips each Flags field on its own
|
||||
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
|
||||
// fails here.
|
||||
func TestFlags_isEqualChecksEveryField(t *testing.T) {
|
||||
typ := reflect.TypeOf(Flags{})
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
require.Equal(t, reflect.Bool, f.Type.Kind(),
|
||||
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
|
||||
|
||||
var a, b Flags
|
||||
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
|
||||
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// populateAll sets every exported field of the struct to a deterministic non-zero
|
||||
// value, recursing into nested structs and the element type of struct slices so
|
||||
// that each leaf differs from zero. It returns the number of exported fields on
|
||||
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
|
||||
// settable exported fields and is comparable with ==).
|
||||
func populateAll(t *testing.T, v reflect.Value) int {
|
||||
t.Helper()
|
||||
|
||||
typ := v.Type()
|
||||
exported := 0
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
if f.PkgPath != "" { // unexported
|
||||
continue
|
||||
}
|
||||
exported++
|
||||
setNonZero(t, v.Field(i))
|
||||
}
|
||||
return exported
|
||||
}
|
||||
|
||||
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
|
||||
// recursing into nested structs and populating one element of slice fields.
|
||||
func setNonZero(t *testing.T, field reflect.Value) {
|
||||
t.Helper()
|
||||
|
||||
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
|
||||
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
|
||||
return
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString("non-zero")
|
||||
case reflect.Bool:
|
||||
field.SetBool(true)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.SetInt(7)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.SetUint(7)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.SetFloat(7)
|
||||
case reflect.Struct:
|
||||
populateAll(t, field)
|
||||
case reflect.Slice:
|
||||
s := reflect.MakeSlice(field.Type(), 1, 1)
|
||||
setNonZero(t, s.Index(0))
|
||||
field.Set(s)
|
||||
default:
|
||||
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
|
||||
}
|
||||
}
|
||||
@@ -125,6 +125,7 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
realProxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -140,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
proxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -21,7 +21,8 @@ AWK_FIRST_FIELD='{print $1}'
|
||||
|
||||
fetch_all_tags() {
|
||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
|
||||
grep -iv 'rc' | \
|
||||
sed 's/.*\/v//' | \
|
||||
sort -u -V
|
||||
return 0
|
||||
|
||||
@@ -32,7 +32,8 @@ fetch_current_ports_version() {
|
||||
fetch_all_tags() {
|
||||
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
|
||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
|
||||
grep -iv 'rc' | \
|
||||
sed 's/.*\/v//' | \
|
||||
sort -u -V
|
||||
return 0
|
||||
|
||||
@@ -26,6 +26,10 @@ type Peer struct {
|
||||
|
||||
// a gRpc connection stream to the Peer
|
||||
Stream proto.SignalExchange_ConnectStreamServer
|
||||
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
|
||||
// the same ServerStream, and a peer can be the target of many senders at
|
||||
// once.
|
||||
sendMu sync.Mutex
|
||||
|
||||
// registration time
|
||||
RegisteredAt time.Time
|
||||
@@ -33,6 +37,13 @@ type Peer struct {
|
||||
Cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Send writes a message to the peer's stream, serializing concurrent senders.
|
||||
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
|
||||
p.sendMu.Lock()
|
||||
defer p.sendMu.Unlock()
|
||||
return p.Stream.Send(msg)
|
||||
}
|
||||
|
||||
// NewPeer creates a new instance of a connected Peer
|
||||
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
|
||||
return &Peer{
|
||||
|
||||
67
signal/server/concurrent_send_test.go
Normal file
67
signal/server/concurrent_send_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/peer"
|
||||
)
|
||||
|
||||
// concurrencyCheckStream records the maximum number of Send calls in flight at
|
||||
// once. gRPC forbids concurrent SendMsg on the same ServerStream, so a correct
|
||||
// server must never have more than one in flight per peer.
|
||||
type concurrencyCheckStream struct {
|
||||
proto.SignalExchange_ConnectStreamServer
|
||||
ctx context.Context
|
||||
inflight atomic.Int32
|
||||
maxSeen atomic.Int32
|
||||
}
|
||||
|
||||
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
|
||||
n := s.inflight.Add(1)
|
||||
for {
|
||||
old := s.maxSeen.Load()
|
||||
if n <= old || s.maxSeen.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Widen the window so overlapping callers are reliably observed.
|
||||
time.Sleep(time.Millisecond)
|
||||
s.inflight.Add(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *concurrencyCheckStream) Context() context.Context { return s.ctx }
|
||||
|
||||
// TestForwardMessageToPeerSerializesSend verifies that concurrent forwards to the
|
||||
// same peer never call Stream.Send concurrently, which would violate the gRPC
|
||||
// ServerStream contract.
|
||||
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
|
||||
s, err := NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
const peerID = "peerX"
|
||||
stream := &concurrencyCheckStream{ctx: context.Background()}
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
require.NoError(t, s.registry.Register(peer.NewPeer(peerID, stream, cancel)))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.forwardMessageToPeer(context.Background(), &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
|
||||
}
|
||||
@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
|
||||
sendResultChan := make(chan error, 1)
|
||||
go func() {
|
||||
select {
|
||||
case sendResultChan <- dstPeer.Stream.Send(msg):
|
||||
case sendResultChan <- dstPeer.Send(msg):
|
||||
return
|
||||
case <-dstPeer.Stream.Context().Done():
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user