Compare commits

..

10 Commits

Author SHA1 Message Date
Viktor Liu
58c79f5878 [client] Fix DNS custom zone teardown: handler leak and external CNAME resolution (#6445) 2026-06-19 17:33:09 +02:00
Viktor Liu
15a0504fb1 [client] Treat answering upstreams as reachable and widen DNS health grace window (#6453) 2026-06-19 17:32:49 +02:00
Riccardo Manfrin
883a1a8961 [client] Fix profile regressions in up --profile and status (#6479)
* Restores behavior to create profile if not there on Up

* Allows to restore nerbird status showing of the profile name

* [client] Reduce upFunc cognitive complexity

Extract the profile switch/auto-create logic from upFunc into a dedicated
switchOrCreateProfile helper. The inlined NotFound-retry branch pushed
upFunc over SonarCloud's cognitive complexity threshold (S3776).
No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* [client] Make up --profile auto-create idempotent under concurrent runs

Don't fail switchOrCreateProfile on a createProfile error: a concurrent
run may create the profile between the NotFound check and our create
call. Retry the switch regardless and only surface the create error if
the switch also fails. Addresses CodeRabbit race-condition feedback.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* Share createProfile with addProfileFunc

* But allow conn reusage

* moves switchOrCreateProfile to where it's used

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-19 16:23:51 +02:00
Maycon Santos
54192a94b7 [misc] handle release candidates when fetching tags in FreeBSD port scripts (#6480)
* [misc] Exclude release candidates when fetching tags in FreeBSD port scripts
2026-06-19 14:10:43 +02:00
Pascal Fischer
8511687270 [management] log peer meta diff (#6468) 2026-06-19 13:30:52 +02:00
Pascal Fischer
35b465fa4a [management] reduce sync and login transaction (#6472) 2026-06-19 11:43:01 +02:00
Brad Ison
fb87f751a5 [management] Fetch complete user data in ValidateTunnelPeer (#6457)
* [management] Fetch complete user data in ValidateTunnelPeer

Previously the `ValidateTunnelPeer` method used by the ProxyService
would fetch user information from the database if the connected peer
was associated with a user ID, but it would not consult the IdP data
for cached info from JWT claims like email.  This caused the value of
the injected `X-Netbird-User` header to always display the peer ID and
never the user email associated with the peer as expected.

This change adds an optional IdP manager to the ProxyService and
fetches the complete user data from it if present.

* [management] Refactor ValidateTunnelPeer principal info gathering

This refactors the gathering of info on proxy tunnel peer principals
into its own method to keep the complexity down and make Sonar happy.
2026-06-19 11:39:21 +02:00
Maycon Santos
679c7182a4 [misc] Remove version prefix v docker tags (#6471) 2026-06-18 22:34:24 +02:00
Pascal Fischer
8c031ea6f0 [management] remove db calls in nested loops (#6470) 2026-06-18 22:12:59 +02:00
Pascal Fischer
60a9544656 [management] pass meta update for browser clients (#6465) 2026-06-18 17:22:42 +02:00
46 changed files with 1592 additions and 1177 deletions

View File

@@ -247,7 +247,7 @@ dockers_v2:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
@@ -295,7 +295,7 @@ dockers_v2:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
@@ -317,7 +317,7 @@ dockers_v2:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
@@ -339,7 +339,7 @@ dockers_v2:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
@@ -361,7 +361,7 @@ dockers_v2:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
@@ -383,7 +383,7 @@ dockers_v2:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
@@ -405,7 +405,7 @@ dockers_v2:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "v{{ .Version }}"
- "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:

View File

@@ -151,9 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, c.recorder)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -186,9 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, c.recorder)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// Stop the internal client and free the resources

View File

@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
Username: &username,
})
if err != nil {
return "", fmt.Errorf("switch profile failed: %v", err)
return "", fmt.Errorf("switch profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil

View File

@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
return err
}
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
if err != nil {
return fmt.Errorf("add profile request: %w", err)
return err
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
}
return err
}
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
// and returns the new profile's ID. It is the single entry point for profile
// creation, shared by `netbird profile add` and the `netbird up --profile
// <name>` auto-create path.
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
ProfileName: profileName,
Username: username,
})
if err != nil {
return "", fmt.Errorf("add profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil
}

View File

@@ -20,7 +20,6 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
@@ -262,46 +261,17 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
return prefix + upper
}
// DialClientGRPCServer returns client connection to the daemon server. It waits
// (up to the timeout) for the daemon to become reachable so an `up` issued right
// after `service start` tolerates the startup race. Instead of grpc's blocking
// dial — whose raw "transport failed" retry warnings are silenced by the logger
// config — we drive the wait ourselves and emit one clean line per failed attempt.
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
conn, err := grpc.DialContext(
return grpc.DialContext(
ctx,
strings.TrimPrefix(addr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
return nil, err
}
conn.Connect()
for {
state := conn.GetState()
if state == connectivity.Ready {
return conn, nil
}
// Log only once the connection has actually failed — not during the
// brief Idle/Connecting phase on a healthy daemon (avoids a spurious
// line + wait when the daemon is already up).
if state == connectivity.TransientFailure {
log.Infof("waiting for the netbird daemon to become available at %s...", addr)
}
// Wake on the next state change, but at least every second so a stuck
// TransientFailure re-logs at a steady cadence until the timeout.
waitCtx, waitCancel := context.WithTimeout(ctx, time.Second)
conn.WaitForStateChange(waitCtx, state)
waitCancel()
if ctx.Err() != nil {
_ = conn.Close()
return nil, fmt.Errorf("daemon not reachable at %s: %w", addr, ctx.Err())
}
}
}
// WithBackOff execute function in backoff cycle.

View File

@@ -11,7 +11,6 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
// Resolve the active profile's display name via the daemon, which runs
// as root and can read the per-user profile files. The local profile
// manager only knows the active profile ID, not its display name.
profName := getActiveProfileName(ctx)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
return resp, nil
}
// getActiveProfileName asks the daemon for the active profile's display
// name. The daemon runs as root and can read the per-user profile files to
// resolve the ID to its human-readable name. Returns an empty string on any
// error so status output degrades gracefully.
func getActiveProfileName(ctx context.Context) string {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return ""
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
if err != nil {
return ""
}
return resp.GetProfileName()
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected":

View File

@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true
}
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
}
// switchOrCreateProfile switches the active profile to the one identified by
// handle, creating it first when it does not exist yet. This restores the
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
// missing profile instead of failing.
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
st, ok := gstatus.FromError(err)
if !ok || st.Code() != codes.NotFound {
return err
}
// Don't fail immediately on a create error: a concurrent run may
// have created the profile between the NotFound above and this
// call, in which case the retried switch still succeeds. Only
// surface the create error if the switch also fails.
_, createErr := createProfile(ctx, handle, username)
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
if createErr != nil {
return fmt.Errorf("create profile: %w", createErr)
}
return err
}
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return err
}
return nil
}
// createProfile dials the daemon and creates a new profile with the given
// display name, returning its generated ID. Use addProfileOnDaemon directly
// when a daemon client is already available to reuse the connection.
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
// override the default profile filepath if provided
if configPath != "" {
@@ -201,10 +241,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, r)
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(config, nil, util.FindFirstLogPath(logFiles))
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {

View File

@@ -264,24 +264,32 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.recorder)
client := internal.NewConnectClient(ctx, c.config, c.recorder)
client.SetSyncResponsePersistence(true)
// The supervisor owns the run; we wait until it is established, ends with a
// startup error (permanent backoff err), or startCtx expires.
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
client.RunAsync(c.config, nil)
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run, ""); err != nil {
clientErr <- err
}
}()
if err := client.WaitEstablishedOrDone(startCtx); err != nil {
// Either startCtx expired while connecting, or the run ended before it
// established. Cancel the client context before stopping: Engine.Start
// blocks on the signal stream while holding the engine mutex and only
// unblocks on cancellation. Stopping first would deadlock on that mutex.
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after startup failure. Stop error: %w. Startup: %w", stopErr, err)
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client

View File

@@ -18,7 +18,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -49,23 +48,13 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidMobileDep is set on Android to inject the MobileDependency for runs
// started through the generic entry points (Run/RunAsync, e.g. embed.Client).
// nil on other platforms, where the dependency is empty.
var androidMobileDep func(config *profilemanager.Config) MobileDependency
// mobileDependency returns the MobileDependency for a run started via the
// generic entry points. On Android the androidMobileDep provider supplies
// platform stubs (or real implementations); elsewhere it is empty.
func (c *ConnectClient) mobileDependency(config *profilemanager.Config) MobileDependency {
if androidMobileDep != nil {
return androidMobileDep(config)
}
return MobileDependency{}
}
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
@@ -74,62 +63,35 @@ type ConnectClient struct {
updateManager *updater.Manager
persistSyncResponse bool
// sup serializes all start/stop requests so two lifecycle operations can
// never overlap. See connect_lifecycle.go.
sup *supervisor
}
func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
) *ConnectClient {
c := &ConnectClient{
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
c.sup = newSupervisor(ctx, c.run)
return c
}
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
c.updateManager = um
}
// Run with main logic. md carries optional gRPC metadata (e.g. the UI
// user-agent) to forward to the management/signal services; nil when none.
func (c *ConnectClient) Run(config *profilemanager.Config, md metadata.MD, logPath string) error {
return c.sup.start(config, md, c.mobileDependency(config), logPath)
}
// RunAsync starts a client run without blocking. Used by the daemon and embed,
// which drive the lifecycle through the supervisor rather than blocking on Run;
// they then wait for the outcome via WaitEstablishedOrDone. The run's lifecycle
// channels are created and owned by the supervisor — callers never hold them.
func (c *ConnectClient) RunAsync(config *profilemanager.Config, md metadata.MD) {
c.sup.startAsync(config, md, c.mobileDependency(config), "", nil)
}
// Restart atomically stops any in-flight run and starts a fresh one with the
// given config. The stop+start happens as a single supervisor operation, so no
// other lifecycle request can interleave between them — used for explicit
// restarts (e.g. an MDM policy change) that must not expose a "stopped" window.
func (c *ConnectClient) Restart(config *profilemanager.Config, md metadata.MD) {
c.sup.restartAsync(config, md, c.mobileDependency(config), "")
}
// WaitEstablishedOrDone blocks until the in-flight run becomes established (nil),
// ends before that (the run error, or a sentinel on a clean stop), or ctx is
// cancelled. Returns errNoRunInFlight if no run is in flight. Wraps the wait on
// the supervisor-owned channels so callers never touch them directly.
func (c *ConnectClient) WaitEstablishedOrDone(ctx context.Context) error {
return c.sup.waitEstablishedOrDone(ctx)
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
}
// RunOnAndroid with main logic on mobile system
func (c *ConnectClient) RunOnAndroid(
config *profilemanager.Config,
tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
@@ -148,11 +110,10 @@ func (c *ConnectClient) RunOnAndroid(
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.sup.start(config, nil, mobileDependency, "")
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) RunOniOS(
config *profilemanager.Config,
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
@@ -170,12 +131,10 @@ func (c *ConnectClient) RunOniOS(
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.sup.start(config, nil, mobileDependency, logFilePath)
return c.run(mobileDependency, nil, logFilePath)
}
// run executes a single client run. runCtx is owned by the supervisor: cancelling
// it tears the run down (it is the parent of the per-attempt engine context).
func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Config, mobileDependency MobileDependency, connEstablishedChan chan struct{}, logPath string) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -239,18 +198,18 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
}()
wrapErr := state.Wrap
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
return wrapErr(err)
}
var mgmTlsEnabled bool
if config.ManagementURL.Scheme == "https" {
if c.config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true
}
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
if err != nil {
return err
}
@@ -284,13 +243,13 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
if runCtx.Err() != nil {
if c.ctx.Err() != nil {
return nil
}
state.Set(StatusConnecting)
engineCtx, cancel := context.WithCancel(runCtx)
engineCtx, cancel := context.WithCancel(c.ctx)
defer func() {
_, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
@@ -298,8 +257,8 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
cancel()
}()
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
@@ -316,7 +275,7 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
}
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
if err = mgmClient.Close(); err != nil {
log.Warnf("failed to close the Management service client %v", err)
@@ -325,14 +284,13 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
loginStarted := time.Now()
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, config)
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil {
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
// No teardown needed: login fails before the engine is started
// (engine.Start is below), so there is nothing running to stop.
_ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error
}
return wrapErr(err)
@@ -386,7 +344,7 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, logPath)
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -430,7 +388,7 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
c.engine = engine
c.engineMutex.Unlock()
if err := engine.Start(loginResp.GetNetbirdConfig(), config.ManagementURL); err != nil {
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -438,13 +396,12 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
// The supervisor owns connEstablishedChan and it is always present. Guard
// against a double close: operation re-runs on ErrResetConnection retries
// within the same run, and the channel is closed only on the first connect.
select {
case <-connEstablishedChan:
default:
close(connEstablishedChan)
if runningChan != nil {
select {
case <-runningChan:
default:
close(runningChan)
}
}
<-engineCtx.Done()
@@ -453,12 +410,14 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
c.engine = nil
c.engineMutex.Unlock()
// Always tear the engine down once its context is cancelled. engine.Stop
// is nil-guarded per component, so calling it unconditionally is safe and
// avoids both the data race on engine.wgInterface and skipping teardown
// when the interface was never brought up (e.g. a mid-start failure).
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
// todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.statusRecorder.ClientTeardown()
@@ -478,9 +437,8 @@ func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Confi
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// Login failed permanently: the engine was never started, so there
// is nothing to tear down — just record that a login is needed.
state.Set(StatusNeedsLogin)
_ = c.Stop()
}
return err
}
@@ -501,22 +459,6 @@ func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
return relayCfg.GetUrls(), token
}
// ConnectionRunning reports whether a connection run is currently in flight
// (connecting, connected, or reconnecting). Answered by the supervisor via a
// serialized query, so it settles behind an in-flight stop. Distinct from
// ServiceRunning, which reports whether the service itself is alive.
func (c *ConnectClient) ConnectionRunning() bool {
return c.sup.isRunning()
}
// ServiceRunning reports whether the client's lifecycle supervisor is alive and
// able to accept start/stop commands — i.e. its context has not been cancelled
// (the daemon is not shutting down). Independent of whether a connection run is
// up (that is ConnectionRunning).
func (c *ConnectClient) ServiceRunning() bool {
return c.sup.ctx.Err() == nil
}
func (c *ConnectClient) Engine() *Engine {
if c == nil {
return nil
@@ -573,10 +515,14 @@ func (c *ConnectClient) Status() StatusType {
return status
}
// Stop serializes a stop request through the lifecycle supervisor and blocks
// until the in-flight run is fully torn down.
func (c *ConnectClient) Stop() error {
return c.sup.stop()
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
}
return nil
}
// SetSyncResponsePersistence enables or disables sync response persistence.

View File

@@ -7,7 +7,6 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
@@ -60,17 +59,19 @@ var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default MobileDependency provider so embed.Client.Start() works
// on Android with netstack mode. Provides complete no-op stubs for all mobile
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by setting
// androidMobileDep before calling Start().
androidMobileDep = func(config *profilemanager.Config) MobileDependency {
return mobileDependencyForEmbed(
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -10,18 +10,23 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// mobileDependencyForEmbed builds the MobileDependency used by embed.Client on
// Android so the engine's existing Android code paths work unchanged.
func mobileDependencyForEmbed(
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
) MobileDependency {
return MobileDependency{
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -1,362 +0,0 @@
package internal
import (
"context"
"errors"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
// errAlreadyRunning is returned when a start is requested while a run is already
// in flight.
var errAlreadyRunning = errors.New("client is already running")
// errNoRunInFlight is returned by waitEstablishedOrDone when no run is active.
var errNoRunInFlight = errors.New("no connection run in flight")
// errStoppedBeforeEstablished is returned when a run ended (cleanly) before the
// connection was established.
var errStoppedBeforeEstablished = errors.New("run stopped before the connection was established")
// lifecycleOp is a serialized lifecycle operation processed by the supervisor.
type lifecycleOp int
const (
opStart lifecycleOp = iota
opStop
opRestart
opStatus
opWaitEstablished
)
// lifecycleCmd is a single lifecycle request handed to the supervisor goroutine.
// They all flow through the same cmdCh so they are strictly ordered (FIFO) with
// respect to each other.
type lifecycleCmd struct {
op lifecycleOp
config *profilemanager.Config
md metadata.MD
mobileDep MobileDependency
logPath string
// done is the caller's notification channel (nil for fire-and-forget). Its
// meaning depends on op:
// - opStart: receives the run's end result when the run terminates, or
// errAlreadyRunning immediately if a run is already in flight.
// - opStop: receives nil once the in-flight run has fully unwound.
// - opWaitEstablished: receives the wait outcome (see waitEstablishedOrDone).
done chan error
reply chan bool // opStatus only: receives whether a run is in flight
waitCtx context.Context // opWaitEstablished only: the waiter's cancellation context
}
// runState holds the lifecycle channels of a single in-flight run, owned by the
// loop goroutine. It never escapes the supervisor as an API; the only readers
// are the per-wait goroutines the loop spawns for opWaitEstablished.
//
// connEstablishedChan is closed by the run once the connection is established.
// The supervisor creates and owns it — callers no longer supply it; they observe
// it through waitEstablishedOrDone. ended is closed (broadcast) when the run
// terminates, so any number of waiters can observe it; err is the run's end
// result, valid only after ended is closed.
type runState struct {
connEstablishedChan chan struct{} // closed by the run on established
ended chan struct{} // closed by finishRun when the run terminates
err error // run end result, valid after ended is closed
}
// runEndResult is sent by the run goroutine to the supervisor when a run ends,
// whether on its own (error / external context cancellation) or because of a Stop.
type runEndResult struct {
err error
}
// runFunc executes a single client run bound to the supervisor-owned context,
// with the config supplied by the start request.
type runFunc func(ctx context.Context, config *profilemanager.Config, mobileDep MobileDependency, connEstablishedChan chan struct{}, logPath string) error
// supervisor serializes start/stop of a single client run. Every request goes
// through cmdCh and is handled one at a time by the loop goroutine, so two
// lifecycle operations can never overlap and their order is preserved (FIFO).
// The loop goroutine is the sole owner of curStart/runCancel, so that state
// needs no locking. The loop exits when the parent context is cancelled.
type supervisor struct {
ctx context.Context
run runFunc
cmdCh chan lifecycleCmd
runEnded chan runEndResult
// owned exclusively by the loop goroutine. curStart is the in-flight start
// command (nil = idle); its done channel is notified when the run ends.
// curRun holds that run's lifecycle channels; runCancel cancels it.
curStart *lifecycleCmd
curRun *runState
runCancel context.CancelFunc
}
func newSupervisor(ctx context.Context, run runFunc) *supervisor {
s := &supervisor{
ctx: ctx,
run: run,
cmdCh: make(chan lifecycleCmd, 16),
runEnded: make(chan runEndResult, 1),
}
go s.loop()
return s
}
func (s *supervisor) loop() {
for {
select {
case <-s.ctx.Done():
s.shutdown()
return
case cmd := <-s.cmdCh:
switch cmd.op {
case opStart:
s.handleStart(cmd)
case opStop:
s.handleStop(cmd)
case opRestart:
s.handleRestart(cmd)
case opStatus:
cmd.reply <- (s.isRunningInternal())
case opWaitEstablished:
s.handleWaitEstablished(cmd)
}
case res := <-s.runEnded:
// Run ended on its own, without an explicit Stop.
s.finishRun(res.err)
}
}
}
func (s *supervisor) handleStart(cmd lifecycleCmd) {
if s.isRunningInternal() {
notify(cmd.done, errAlreadyRunning)
return
}
runCtx, cancel := context.WithCancel(s.ctx)
if cmd.md != nil {
// Carry caller-supplied gRPC metadata (e.g. UI user-agent) into the run
// context so the engine's management/signal calls forward it. The cancel
// still drives runCtx (metadata wrapping preserves cancellation).
runCtx = metadata.NewOutgoingContext(runCtx, cmd.md)
}
s.runCancel = cancel
s.curStart = &cmd
s.curRun = &runState{connEstablishedChan: make(chan struct{}), ended: make(chan struct{})}
go func(ctx context.Context, cfg *profilemanager.Config, m MobileDependency, established chan struct{}, lp string) {
err := s.run(ctx, cfg, m, established, lp)
s.runEnded <- runEndResult{err: err}
}(runCtx, cmd.config, cmd.mobileDep, s.curRun.connEstablishedChan, cmd.logPath)
}
func (s *supervisor) handleStop(cmd lifecycleCmd) {
if !s.isRunningInternal() {
notify(cmd.done, nil)
return
}
s.stopCurrentRun()
notify(cmd.done, nil)
}
// handleRestart tears down any in-flight run and starts a fresh one in a single
// loop turn. No other command can interleave between the stop and the start
// (the loop is single-threaded), so the swap is atomic without relying on any
// daemon-side lock — that is what an explicit restart (e.g. MDM config change)
// needs to avoid a window where the client is observably stopped.
func (s *supervisor) handleRestart(cmd lifecycleCmd) {
if s.isRunningInternal() {
s.stopCurrentRun()
}
s.handleStart(cmd)
}
// stopCurrentRun cancels the in-flight run and blocks the supervisor until it
// has fully unwound, so the next action starts from a clean slate. The run
// goroutine reports completion via runEnded. Caller must hold an in-flight run
// (curStart != nil).
func (s *supervisor) stopCurrentRun() {
s.runCancel()
res := <-s.runEnded
s.finishRun(res.err)
}
// finishRun resets lifecycle state after a run terminates and hands the run
// error back to whoever asked to be notified of the start.
func (s *supervisor) finishRun(err error) {
s.runCancel = nil
if s.isRunningInternal() {
// Publish the result to the broadcast channel before nil-ing curRun, so
// any opWaitEstablished goroutines blocked on ended observe err.
s.curRun.err = err
close(s.curRun.ended)
s.curRun = nil
notify(s.curStart.done, err)
s.curStart = nil
}
}
// handleWaitEstablished answers an opWaitEstablished request. The select itself
// runs in a spawned goroutine on the run's channels so it never blocks the loop;
// the loop only snapshots the in-flight run's channels (which it owns) here.
func (s *supervisor) handleWaitEstablished(cmd lifecycleCmd) {
caller := cmd.done
if !s.isRunningInternal() {
notify(caller, errNoRunInFlight)
return
}
rs := s.curRun
established := rs.connEstablishedChan
ctx := cmd.waitCtx
go func() {
select {
case <-established:
notify(caller, nil)
case <-rs.ended:
if rs.err != nil {
notify(caller, rs.err)
return
}
notify(caller, errStoppedBeforeEstablished)
case <-ctx.Done():
notify(caller, ctx.Err())
}
}()
}
// shutdown tears down the in-flight run when the parent context is cancelled,
// then fails any still-queued commands so their callers never hang.
func (s *supervisor) shutdown() {
if s.runCancel != nil {
s.runCancel()
res := <-s.runEnded
s.finishRun(res.err)
}
for {
select {
case cmd := <-s.cmdCh:
notify(cmd.done, s.ctx.Err())
default:
return
}
}
}
// startAsync enqueues a start without blocking. If done is non-nil it receives
// the run's end result (or errAlreadyRunning on rejection, or the context error
// on shutdown).
func (s *supervisor) startAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string, done chan error) {
cmd := lifecycleCmd{op: opStart, config: config, md: md, mobileDep: mobileDep, logPath: logPath, done: done}
select {
case s.cmdCh <- cmd:
case <-s.ctx.Done():
notify(done, s.ctx.Err())
}
}
// restartAsync enqueues an atomic stop+start without blocking. The supervisor
// tears down any in-flight run and starts a fresh one with the supplied config
// in a single loop turn (see handleRestart). Fire-and-forget: the new run owns
// its lifecycle channels, observed via waitEstablishedOrDone.
func (s *supervisor) restartAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) {
cmd := lifecycleCmd{op: opRestart, config: config, md: md, mobileDep: mobileDep, logPath: logPath}
select {
case s.cmdCh <- cmd:
case <-s.ctx.Done():
}
}
// start enqueues a start and blocks until the run terminates, preserving the
// blocking contract of the legacy Run entry points.
func (s *supervisor) start(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) error {
done := make(chan error, 1)
s.startAsync(config, md, mobileDep, logPath, done)
select {
case err := <-done:
return err
case <-s.ctx.Done():
return s.ctx.Err()
}
}
// isRunning asks the loop whether a run is in flight. The query is serialized
// with start/stop, so during a stop it waits for the teardown to settle and
// then reports the final state — never a transient "half-stopped".
func (s *supervisor) isRunning() bool {
reply := make(chan bool, 1)
select {
case s.cmdCh <- lifecycleCmd{op: opStatus, reply: reply}:
case <-s.ctx.Done():
return false
}
select {
case r := <-reply:
return r
case <-s.ctx.Done():
return false
}
}
func (s *supervisor) isRunningInternal() bool {
return s.curStart != nil
}
// waitEstablishedOrDone blocks until the in-flight run becomes established
// (returns nil) or ends before that (returns the run error, or
// errStoppedBeforeEstablished on a clean stop), or ctx is cancelled. Returns
// errNoRunInFlight if no run is in flight. The wait is performed by a goroutine
// spawned inside the loop (see handleWaitEstablished); the run's channels never
// leave the supervisor.
func (s *supervisor) waitEstablishedOrDone(ctx context.Context) error {
reply := make(chan error, 1)
select {
case s.cmdCh <- lifecycleCmd{op: opWaitEstablished, waitCtx: ctx, done: reply}:
case <-ctx.Done():
return ctx.Err()
case <-s.ctx.Done():
return s.ctx.Err()
}
select {
case err := <-reply:
return err
case <-s.ctx.Done():
return s.ctx.Err()
}
}
// stop enqueues a stop and blocks until the in-flight run is fully torn down.
func (s *supervisor) stop() error {
done := make(chan error, 1)
select {
case s.cmdCh <- lifecycleCmd{op: opStop, done: done}:
case <-s.ctx.Done():
return s.ctx.Err()
}
select {
case err := <-done:
return err
case <-s.ctx.Done():
return s.ctx.Err()
}
}
// notify sends on a caller-supplied channel without blocking. The channel is
// expected to be buffered (cap 1); a nil channel means the caller did not ask
// to be notified.
func notify(ch chan error, err error) {
if ch == nil {
return
}
select {
case ch <- err:
default:
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
@@ -38,11 +39,15 @@ const (
// defaultWarningDelayBase is the starting grace window before a
// "Nameserver group unreachable" event fires for a group that's
// never been healthy and only has overlay upstreams with no
// Connected peer. Per-server and overridable; see warningDelayFor.
defaultWarningDelayBase = 30 * time.Second
// Connected peer. Per-server and overridable via envWarningDelay;
// see warningDelay.
defaultWarningDelayBase = 60 * time.Second
// warningDelayBonusCap caps the route-count bonus added to the
// base grace window. See warningDelayFor.
// base grace window. See warningDelay.
warningDelayBonusCap = 30 * time.Second
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
)
// errNoUsableNameservers signals that a merged-domain group has no usable
@@ -135,7 +140,7 @@ type DefaultServer struct {
disableSys bool
mux sync.Mutex
service service
dnsMuxMap registeredHandlerMap
dnsMuxHandlers []handlerWrapper
localResolver *local.Resolver
wgInterface WGIface
hostManager hostManager
@@ -199,8 +204,6 @@ type handlerWrapper struct {
priority int
}
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
@@ -289,7 +292,6 @@ func newDefaultServer(
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(),
wgInterface: wgInterface,
statusRecorder: statusRecorder,
@@ -298,7 +300,7 @@ func newDefaultServer(
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
warningDelayBase: defaultWarningDelayBase,
warningDelayBase: warningDelayBaseFromEnv(),
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
type routeSettable interface {
setSelectedRoutes(func() route.HAMap)
}
for _, entry := range s.dnsMuxMap {
for _, entry := range s.dnsMuxHandlers {
if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected)
}
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
for _, existing := range s.dnsMuxHandlers {
s.deregisterHandler([]string{existing.domain}, existing.priority)
existing.handler.Stop()
// The local resolver is a persistent singleton shared by every custom
// zone and reused across config updates. Its chain registrations are
// per-config and must be deregistered, but Stop() cancels its lookup
// context (breaking external CNAME-target resolution) and clears its
// records, so it must not be torn down here.
if existing.handler != s.localResolver {
existing.handler.Stop()
}
}
muxUpdateMap := make(registeredHandlerMap)
for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update
}
s.dnsMuxMap = muxUpdateMap
s.dnsMuxHandlers = muxUpdates
}
// updateNSGroupStates records the new group set and pokes the refresher.
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
return false
}
// warningDelayBaseFromEnv returns the base grace window, honoring
// envWarningDelay when it holds a valid positive Go duration. Invalid or
// non-positive values fall back to defaultWarningDelayBase.
func warningDelayBaseFromEnv() time.Duration {
val := os.Getenv(envWarningDelay)
if val == "" {
return defaultWarningDelayBase
}
d, err := time.ParseDuration(val)
if err != nil {
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
return defaultWarningDelayBase
}
if d <= 0 {
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
return defaultWarningDelayBase
}
return d
}
// warningDelay returns the grace window for the given selected-route
// count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
// in more than one handler.
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
merged := make(map[netip.AddrPort]UpstreamHealth)
for _, entry := range s.dnsMuxMap {
for _, entry := range s.dnsMuxHandlers {
reporter, ok := entry.handler.(upstreamHealthReporter)
if !ok {
continue

View File

@@ -104,19 +104,6 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
@@ -132,22 +119,20 @@ func TestUpdateDNSServer(t *testing.T) {
},
}
dummyHandler := local.NewResolver()
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registeredHandlerMap
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -169,20 +154,17 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
dummyHandler.ID(): handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
@@ -191,10 +173,10 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
@@ -215,15 +197,13 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
"local-resolver": handlerWrapper{
{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
},
@@ -232,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
@@ -240,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -262,7 +242,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -284,7 +264,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -301,42 +281,41 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap),
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: dummyHandler,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap),
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
@@ -393,7 +372,7 @@ func TestUpdateDNSServer(t *testing.T) {
}
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
@@ -405,14 +384,20 @@ func TestUpdateDNSServer(t *testing.T) {
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
@@ -512,8 +497,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
}()
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
@@ -1029,15 +1014,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := registeredHandlerMap{
"upstream-group1": {
baseMatchHandlers := []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
"upstream-group2": {
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
@@ -1046,15 +1031,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseRootHandlers := registeredHandlerMap{
"upstream-root1": {
baseRootHandlers := []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
"upstream-root2": {
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
@@ -1063,22 +1048,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseMixedHandlers := registeredHandlerMap{
"upstream-group1": {
baseMixedHandlers := []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
"upstream-group2": {
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
},
"upstream-other": {
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
@@ -1089,7 +1074,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
tests := []struct {
name string
initialHandlers registeredHandlerMap
initialHandlers []handlerWrapper
updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain
description string
@@ -1373,32 +1358,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{
dnsMuxMap: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
dnsMuxHandlers: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Perform the update
server.updateMux(tt.updates)
// Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
"Number of handlers after update doesn't match expected")
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
var found *handlerWrapper
for i := range server.dnsMuxHandlers {
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
found = &server.dnsMuxHandlers[i]
break
}
}
assert.NotNil(t, found, "Expected handler %s not found", id)
if found != nil {
assert.Equal(t, expectedDomain, found.domain,
"Domain mismatch for handler %s", id)
}
}
// Verify no unexpected handlers exist
for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
for _, entry := range server.dnsMuxHandlers {
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
}
// Verify the handlerChain state and order
@@ -1413,7 +1404,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Verify handler exists in mux
foundInMux := false
for _, muxEntry := range server.dnsMuxMap {
for _, muxEntry := range server.dnsMuxHandlers {
if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
@@ -1422,12 +1413,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}
}
assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxMap")
"Handler in chain not found in dnsMuxHandlers")
}
})
}
}
// chainHasPattern reports whether the handler chain holds an entry registered
// for the given fqdn pattern at the given priority.
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
for _, h := range s.handlerChain.handlers {
if h.OrigPattern == pattern && h.Priority == priority {
return true
}
}
return false
}
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
// tracks each (handler, domain) registration independently when one handler
// serves multiple zones. Every custom zone is served by the same handler
// instance (the local resolver, whose ID is the constant "local-resolver"), so
// removing one zone must deregister exactly that zone's chain entry and leave
// the others in place. Tracking registrations by handler ID alone collapses all
// zones onto one entry, leaving removed zones in the chain to answer
// authoritatively with no records.
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
// One handler serves every custom zone, mirroring s.localResolver.
shared := &mockHandler{Id: "local-resolver"}
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Two custom zones under the same handler. The surviving zone is registered
// last, mirroring the management emission order.
server.updateMux([]handlerWrapper{
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test should be registered after the first update")
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should be registered after the first update")
// Remove one zone, keep the other.
server.updateMux([]handlerWrapper{
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should remain after removing userzone.test")
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test handler must be deregistered, not leaked in the chain")
}
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
// does not tear down the shared local resolver during reconfiguration. The
// resolver is a process-lifetime singleton reused across config updates;
// Stop() cancels its lookup context (breaking external CNAME-target
// resolution) and clears its records. updateMux must deregister its chain
// entries without stopping it. Records surviving a teardown update is the
// observable proxy: Stop() would have cleared them.
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
resolver := local.NewResolver()
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
Name: "peer.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}))
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
localResolver: resolver,
}
server.updateMux([]handlerWrapper{
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
})
// Remove the zone. The resolver must survive so its records and lookup
// context stay intact for the next registration.
server.updateMux(nil)
var response *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
response = m
return nil
},
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
require.NotNil(t, response, "local resolver should answer after teardown")
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
}
func TestExtraDomains(t *testing.T) {
tests := []struct {
name string
@@ -2049,7 +2136,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
@@ -2207,7 +2293,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
}
}
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
@@ -2283,12 +2369,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
@@ -2395,7 +2480,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
@@ -2407,7 +2491,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2444,7 +2528,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
@@ -2459,7 +2542,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2484,6 +2567,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestWarningDelayBaseFromEnv(t *testing.T) {
tests := []struct {
name string
set bool
val string
want time.Duration
}{
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envWarningDelay, tc.val)
if !tc.set {
os.Unsetenv(envWarningDelay)
}
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
})
}
}
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
@@ -2595,7 +2704,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
@@ -2613,7 +2721,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2640,7 +2748,6 @@ func TestDNSLoopPrevention(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct {

View File

@@ -443,21 +443,25 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
// A valid response means the upstream is reachable, whatever the Rcode.
u.markUpstreamOk(upstream)
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
// refused zones, transient recursion errors), not reachability
// problems: fail over for a better answer but keep the upstream healthy.
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
@@ -465,7 +469,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}

View File

@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
// those are per-question outcomes from a reachable server and must not mark
// the upstream unhealthy. Only transport failures (timeouts) do.
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.10:53")
b := netip.MustParseAddrPort("192.0.2.11:53")
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
tests := []struct {
name string
respA mockUpstreamResponse
respB mockUpstreamResponse
wantHealthy bool
}{
{
name: "both SERVFAIL are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
wantHealthy: true,
},
{
name: "both REFUSED are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
wantHealthy: true,
},
{
name: "timeout marks unhealthy",
respA: mockUpstreamResponse{err: timeoutErr},
respB: mockUpstreamResponse{err: timeoutErr},
wantHealthy: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): tc.respA,
b.String(): tc.respB,
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a, b})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, a, "primary upstream should have a health record")
if tc.wantHealthy {
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
} else {
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
}
})
}
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string

View File

@@ -22,8 +22,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
@@ -1127,20 +1125,6 @@ func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
}
// wrapDisconnectError classifies a receive-loop failure before the run is torn
// down. An auth rejection (PermissionDenied/Unauthenticated) means the session
// needs re-login and retrying is futile, so mark it terminal (NeedsLogin) — run()
// then exits on its own instead of spinning the backoff. Any other failure is a
// recoverable connection reset that the backoff should retry.
func (e *Engine) wrapDisconnectError(err error) {
state := CtxGetState(e.ctx)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied || s.Code() == codes.Unauthenticated) {
state.Set(StatusNeedsLogin)
return
}
_ = state.Wrap(ErrResetConnection)
}
func (e *Engine) receiveJobEvents() {
e.jobExecutorWG.Add(1)
go func() {
@@ -1167,9 +1151,9 @@ func (e *Engine) receiveJobEvents() {
}
})
if err != nil {
// happens if management is unavailable for a long time, or rejects
// us (auth). wrapDisconnectError decides retry vs needs-login.
e.wrapDisconnectError(err)
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
@@ -1251,9 +1235,9 @@ func (e *Engine) receiveManagementEvents() {
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time, or rejects
// us (auth). wrapDisconnectError decides retry vs needs-login.
e.wrapDisconnectError(err)
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
@@ -1777,9 +1761,9 @@ func (e *Engine) receiveSignalEvents() {
return nil
})
if err != nil {
// happens if signal is unavailable for a long time, or rejects us
// (auth). wrapDisconnectError decides retry vs needs-login.
e.wrapDisconnectError(err)
// happens if signal is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}

View File

@@ -171,13 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
connectClient := internal.NewConnectClient(ctx, c.recorder)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, connectClient)
// Persist the latest sync response so DebugBundle can include the network
// map. On iOS this is backed by disk to keep it out of the constrained
// process memory (see the syncstore package).
connectClient.SetSyncResponsePersistence(true)
return connectClient.RunOniOS(cfg, fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
}
// Stop the internal client and free the resources

View File

@@ -344,6 +344,9 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
}
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
if s.connectClient == nil {
return nil, status.Error(codes.FailedPrecondition, "client not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")

View File

@@ -5,6 +5,7 @@ package server
import (
"bytes"
"context"
"errors"
"fmt"
"runtime/pprof"
@@ -27,9 +28,11 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
}
var clientMetrics debug.MetricsExporter
if engine := s.connectClient.Engine(); engine != nil {
if cm := engine.GetClientMetrics(); cm != nil {
clientMetrics = cm
if s.connectClient != nil {
if engine := s.connectClient.Engine(); engine != nil {
if cm := engine.GetClientMetrics(); cm != nil {
clientMetrics = cm
}
}
}
@@ -45,10 +48,13 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
defer s.cleanupBundleCapture()
var refreshStatus func()
if engine := s.connectClient.Engine(); engine != nil {
refreshStatus = func() {
log.Debug("refreshing system health status for debug bundle")
engine.RunHealthProbes(true)
if s.connectClient != nil {
engine := s.connectClient.Engine()
if engine != nil {
refreshStatus = func() {
log.Debug("refreshing system health status for debug bundle")
engine.RunHealthProbes(true)
}
}
}
@@ -112,7 +118,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
log.SetLevel(level)
s.connectClient.SetLogLevel(level)
if s.connectClient != nil {
s.connectClient.SetLogLevel(level)
}
log.Infof("Log level set to %s", level.String())
@@ -126,13 +134,20 @@ func (s *Server) SetSyncResponsePersistence(_ context.Context, req *proto.SetSyn
enabled := req.GetEnabled()
s.persistSyncResponse = enabled
s.connectClient.SetSyncResponsePersistence(enabled)
if s.connectClient != nil {
s.connectClient.SetSyncResponsePersistence(enabled)
}
return &proto.SetSyncResponsePersistenceResponse{}, nil
}
func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
return s.connectClient.GetLatestSyncResponse()
cClient := s.connectClient
if cClient == nil {
return nil, errors.New("connect client is not initialized")
}
return cClient.GetLatestSyncResponse()
}
// StartCPUProfile starts CPU profiling in the daemon.

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
@@ -38,11 +39,12 @@ type conflictCheck struct {
// OS-native managed-config store reports a diff vs the last observation.
//
// Restart sequence:
// 1. Stop the in-flight run via the supervisor (blocks until fully torn down).
// 2. Re-resolve Config from disk + MDM policy (Config.apply re-runs
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
// applyMDMPolicy with the freshly loaded Policy).
// 3. Start a fresh run with the new config.
// 4. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
// RPC) can refresh its cached config view without polling.
//
// The callback runs in the ticker's own goroutine. Ticker has already
@@ -50,24 +52,39 @@ type conflictCheck struct {
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
log.Warn("MDM policy changed; restarting engine to apply new configuration")
// Hold s.mutex for the entire restart sequence (stop + re-start). Any
// concurrent Up/Down/Status arriving while MDM is restarting blocks on the
// Lock until we are done — they then observe the post-restart state coherently.
// Hold s.mutex for the entire restart sequence (cancel + quiescence
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
// MDM is restarting blocks on the Lock until we are done — they
// then observe the post-restart state coherently. This is safe
// because the connectWithRetryRuns goroutine no longer acquires
// s.mutex in its defer (intent vs. goroutine-alive concerns are
// fully separated; see the connectionGoroutineRunning helper).
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.connectClient.ConnectionRunning() {
// No run in flight, so there's no engine to restart.
if !s.clientRunning {
// The client is not running, so there's no engine to restart.
return nil
}
// Cancel daemon-side login/status activities tied to the old run; the run
// itself is torn down atomically by the supervisor inside Restart (see
// restartEngineForMDMLocked), which stops and re-starts in one operation.
if s.actCancel != nil {
s.actCancel()
}
// Wait for previous connectWithRetryRuns to exit so we don't end up
// with two goroutines fighting over the same status recorder + engine.
// The teardown engages a fan-out of engine goroutines (peer workers,
// signal handler, route manager, ...). close(clientGiveUpChan)
// happens in the function-scope defer of connectWithRetryRuns, on
// every exit path (ctx cancel, backoff exhausted, panic) — see the
// defer in server.go.
if s.clientGiveUpChan != nil {
select {
case <-s.clientGiveUpChan:
case <-time.After(10 * time.Second):
return fmt.Errorf("failed to restart the engine due to timeout")
}
}
if err := s.restartEngineForMDMLocked(); err != nil {
log.Errorf("MDM restart failed: %v", err)
return err
@@ -114,13 +131,14 @@ func (s *Server) publishConfigChangedEvent(source string) {
}
// restartEngineForMDMLocked re-resolves the active profile config
// (re-running applyMDMPolicy via Config.apply) and starts a fresh run.
// Mirrors the tail of Server.Start so a runtime MDM change behaves
// identically to a fresh boot under the new policy.
// (re-running applyMDMPolicy via Config.apply) and re-spawns
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
// MDM change behaves identically to a fresh boot under the new policy.
//
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
// for the entire restart sequence so concurrent Up/Down/Status RPCs
// observe a coherent post-restart state.
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
// state.
func (s *Server) restartEngineForMDMLocked() error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
@@ -136,13 +154,13 @@ func (s *Server) restartEngineForMDMLocked() error {
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
_, cancel := context.WithCancel(s.rootCtx)
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
log.Info("MDM restart: atomically restarting the run with re-resolved config")
// MDM restart has no incoming RPC metadata; fire and forget. Restart is a
// single supervisor op (atomic stop+start), so there is no observable
// "stopped" window between tearing down the old run and starting the new.
s.connectClient.Restart(config, nil)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.publishConfigChangedEvent("mdm")
return nil
}

View File

@@ -34,6 +34,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
@@ -143,6 +147,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
@@ -191,6 +199,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")

View File

@@ -8,10 +8,12 @@ import (
"os"
"os/exec"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
@@ -37,7 +39,15 @@ import (
)
const (
probeThreshold = time.Second * 5
probeThreshold = time.Second * 5
retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
defaultInitialRetryTime = 30 * time.Minute
defaultMaxRetryInterval = 60 * time.Minute
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon (disabled by default)
defaultJWTCacheTTL = 0
@@ -62,8 +72,15 @@ type Server struct {
mutex sync.Mutex
config *profilemanager.Config
proto.UnimplementedDaemonServiceServer
// Run state (in-flight? established/done channels?) is owned entirely by the
// supervisor inside connectClient — the daemon keeps no per-run fields.
// clientRunning tracks "the daemon wants to be connected" — set true by
// Start / Up, cleared by Down / Logout. Persists across retry
// loops, signal disconnects, and ErrResetConnection cycles. NOT
// changed by connectWithRetryRuns goroutine exit — for that
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
// derives from clientGiveUpChan close state. Protected by s.mutex.
clientRunning bool
clientRunningChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
connectClient *internal.ConnectClient
@@ -119,13 +136,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
networksDisabled: networksDisabled,
jwtCache: newJWTCache(),
}
// The ConnectClient is daemon-lifetime: build it exactly once, here. Its
// supervisor lives as long as the daemon; Up/Down/MDM and reconnects all
// drive this same instance. updateManager isn't ready yet (created in
// Start) and is injected there via SetUpdateManager.
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
agent := &serverAgent{s}
s.sleepHandler = sleephandler.New(agent)
s.startSleepDetector()
@@ -137,7 +147,7 @@ func (s *Server) Start() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.connectClient.ConnectionRunning() {
if s.clientRunning {
return nil
}
@@ -155,7 +165,6 @@ func (s *Server) Start() error {
stateMgr := statemanager.New(s.profileManager.GetStatePath())
s.updateManager = updater.NewManager(s.statusRecorder, stateMgr)
s.updateManager.CheckUpdateSuccess(s.rootCtx)
s.connectClient.SetUpdateManager(s.updateManager)
}
// MDM policy reload ticker: every minute the desktop daemon re-reads
@@ -181,9 +190,7 @@ func (s *Server) Start() error {
return nil
}
// actCancel cancels in-flight foreground operations (login/status); the run
// itself is owned by the supervisor and stopped via Stop, not this cancel.
_, cancel := context.WithCancel(s.rootCtx)
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
// copy old default config
@@ -225,14 +232,99 @@ func (s *Server) Start() error {
return nil
}
// Boot autoconnect: no incoming RPC metadata. The supervisor runs the
// client and reconnects internally; we just fire and forget (the run owns
// its established/done channels).
s.connectClient.RunAsync(config, nil)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.publishConfigChangedEvent("startup")
return nil
}
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
//
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
// — placed in the function-scope defer so every return path (panic,
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
// it. Callers that need to observe "is the goroutine still alive?" use
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
// goroutine.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
if giveUpChan != nil {
close(giveUpChan)
}
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
return
}
backOff := getConnectWithBackoff(ctx)
go func() {
t := time.NewTicker(24 * time.Hour)
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status")
backOff.Reset()
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
}
}
}()
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
}
log.Tracef("client connection exited gracefully, do not need to retry")
return nil
}
if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err)
}
// giveUpChan is closed by the function-scope defer.
}
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
// still running. Returns false when no goroutine has ever been started
// AND when the most recent one has already closed clientGiveUpChan on
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
// completion, or backoff retry exhaustion).
//
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
// is written by Start/Up under the same lock.
func (s *Server) connectionGoroutineRunning() bool {
if s.clientGiveUpChan == nil {
return false
}
select {
case <-s.clientGiveUpChan:
return false
default:
return true
}
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
@@ -628,22 +720,13 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
// Up starts engine work in the daemon.
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
s.mutex.Lock()
// The client (and its supervisor) is built once in New(), so a nil here
// never happens in production — Up is only reachable after New() has run and
// the gRPC server is serving. The real case this guards is the daemon
// SHUTTING DOWN: rootCtx is cancelled, the supervisor is no longer accepting
// commands, so ServiceRunning() is false even though the client exists. Bail
// loud instead of enqueuing a run that will never start. (nil only happens in
// tests that build a Server without New(); ServiceRunning is nil-safe.)
if !s.connectClient.ServiceRunning() {
s.mutex.Unlock()
return nil, fmt.Errorf("service is not running, start the netbird service for 'up' to take effect")
}
// If a connection run is already in flight, the existing engine is on the
// job — just wait for it. Otherwise fall through to start a fresh run.
if s.connectClient.ConnectionRunning() {
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
// goroutine is still trying. When intent is up AND goroutine is alive,
// the existing engine is on the job — just wait for it. When intent
// is up but the goroutine has given up (backoff exhausted) OR when
// intent is down, fall through to spawn a fresh retry loop.
if s.clientRunning && s.connectionGoroutineRunning() {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
@@ -681,13 +764,13 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
if s.actCancel != nil {
s.actCancel()
}
// actCancel cancels in-flight foreground ops (login/status); the run is
// owned by the supervisor and stopped via Stop, not this cancel.
_, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
ctx, cancel := context.WithCancel(s.rootCtx)
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
ctx = metadata.NewOutgoingContext(ctx, md)
}
// Forward the caller's gRPC metadata (e.g. UI user-agent) into the run.
md, _ := metadata.FromIncomingContext(callerCtx)
s.actCancel = cancel
if s.config == nil {
s.mutex.Unlock()
@@ -729,26 +812,35 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
s.connectClient.RunAsync(s.config, md)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.publishConfigChangedEvent("up_rpc")
s.mutex.Unlock()
return s.waitForUp(callerCtx)
}
// waitForUp blocks until the in-flight run becomes established (success) or ends
// before that (failure). The wait is owned by the supervisor (via the client) —
// the daemon holds no per-run state here.
// todo: handle potential race conditions
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
if err := s.connectClient.WaitEstablishedOrDone(timeoutCtx); err != nil {
log.Debugf("waiting for the connection to be established failed: %v", err)
return nil, fmt.Errorf("connection not established: %w", err)
select {
case <-s.clientGiveUpChan:
return nil, fmt.Errorf("client gave up to connect")
case <-s.clientRunningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready")
return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
}
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
}
// resolveProfileHandle resolves a wire-level profile handle (display
@@ -843,11 +935,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
// Down engine work in the daemon.
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
// cleanupConnection stops the run through the supervisor, which blocks until
// the run has fully unwound — no separate goroutine-quiescence wait needed.
giveUpChan := s.clientGiveUpChan
if err := s.cleanupConnection(); err != nil {
s.mutex.Unlock()
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
@@ -856,6 +948,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.mutex.Unlock()
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
// The giveUpChan is closed at the end of connectWithRetryRuns.
if giveUpChan != nil {
select {
case <-giveUpChan:
log.Debugf("client goroutine finished successfully")
case <-time.After(5 * time.Second):
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
}
}
return &proto.DownResponse{}, nil
}
@@ -866,19 +972,34 @@ func (s *Server) cleanupConnection() error {
return ErrServiceNotUp
}
// Tear the client down through the lifecycle supervisor BEFORE cancelling
// the retry context. Stop serializes on the supervisor queue and blocks
// until the in-flight run has fully unwound (a clean, synchronous teardown).
// It must run before actCancel: cancelling the context first would make
// Stop observe a dead context and return early without waiting.
if err := s.connectClient.Stop(); err != nil {
return err
// Daemon intent flips to "down" — all callers (Down RPC,
// Logout RPC handlers) tear down the connection because the user
// explicitly asked for it. MDM restart does NOT go through this
// path, so its clientRunning stays true.
s.clientRunning = false
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
// to skip the engine shutdown entirely.
var engine *internal.Engine
if s.connectClient != nil {
engine = s.connectClient.Engine()
}
// Stop the retry goroutine so it does not start a fresh run. The client
// itself is daemon-lifetime and intentionally kept (a later Up reuses it).
s.actCancel()
if s.connectClient == nil {
return nil
}
if engine != nil {
if err := engine.Stop(); err != nil {
return err
}
}
s.connectClient = nil
s.isSessionActive.Store(false)
log.Infof("service is down")
@@ -1013,7 +1134,7 @@ func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfi
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.ID == profile.ID && s.connectClient.ConnectionRunning() {
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
return s.sendLogoutRequest(ctx)
}
@@ -1059,13 +1180,48 @@ func (s *Server) Status(
ctx context.Context,
msg *proto.StatusRequest,
) (*proto.StatusResponse, error) {
// A run that hits a terminal auth failure now exits on its own (engine marks
// NeedsLogin), so we no longer poll-and-cancel: we wait for the in-flight run
// to become established or to end. With no run in flight this returns
// immediately (errNoRunInFlight); either way we then report the status below.
if msg.WaitForReady != nil && *msg.WaitForReady {
if err := s.connectClient.WaitEstablishedOrDone(ctx); err != nil && ctx.Err() != nil {
return nil, ctx.Err()
s.mutex.Lock()
// Only wait if the retry-loop goroutine is alive and making
// progress. clientRunning=true with connectionGoroutineRunning=false means the
// backoff has given up — there is nothing to wait for; let the
// caller observe the failed status directly.
alive := s.connectionGoroutineRunning()
s.mutex.Unlock()
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
return nil, err
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-s.clientGiveUpChan:
ticker.Stop()
break loop
case <-s.clientRunningChan:
ticker.Stop()
break loop
case <-ticker.C:
status, err := state.Status()
if err != nil {
continue
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
continue
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
@@ -1103,6 +1259,10 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
@@ -1140,6 +1300,10 @@ func (s *Server) GetPeerSSHHostKey(
statusRecorder := s.statusRecorder
s.mutex.Unlock()
if connectClient == nil {
return nil, errors.New("client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
@@ -1306,13 +1470,17 @@ func (s *Server) WaitJWTToken(
// ExposeService exposes a local port via the NetBird reverse proxy.
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
s.mutex.Lock()
if !s.connectClient.ConnectionRunning() {
if !s.clientRunning {
s.mutex.Unlock()
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
}
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
@@ -1366,6 +1534,10 @@ func isUnixRunningDesktop() bool {
}
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
}
engine := s.connectClient.Engine()
if engine == nil {
return
@@ -1644,6 +1816,22 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
log.Tracef("running client connection")
client := internal.NewConnectClient(ctx, config, statusRecorder)
client.SetUpdateManager(s.updateManager)
client.SetSyncResponsePersistence(s.persistSyncResponse)
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
if err := client.Run(runningChan, s.logFile); err != nil {
return err
}
return nil
}
// MDM authority: when the platform-native MDM source sets a kill switch
// key (regardless of true/false value), that value wins. The CLI flag
// supplied at service install time is the fallback used only when the
@@ -1705,6 +1893,45 @@ func (s *Server) onSessionExpire() {
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -15,19 +15,14 @@ import (
)
func newTestServer() *Server {
ctx := context.Background()
s := &Server{
rootCtx: ctx,
return &Server{
rootCtx: context.Background(),
statusRecorder: peer.NewRecorder(""),
}
// Honor the production invariant: the daemon-lifetime client always exists
// (built in New). Server methods rely on s.connectClient being non-nil.
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
return s
}
func newDummyConnectClient(ctx context.Context) *internal.ConnectClient {
return internal.NewConnectClient(ctx, nil)
return internal.NewConnectClient(ctx, nil, nil)
}
// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient
@@ -92,36 +87,41 @@ func TestConcurrentConnectClientAccess(t *testing.T) {
assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic")
}
// TestCleanupConnection_KeepsClientStopsRunning validates that cleanupConnection
// clears the daemon "up" intent but KEEPS the daemon-lifetime ConnectClient
// (it is reused across Up/Down; only the run is stopped).
func TestCleanupConnection_KeepsClientStopsRunning(t *testing.T) {
// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection
// properly nils out connectClient.
func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
s := newTestServer()
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
s.connectClient = newDummyConnectClient(context.Background())
s.clientRunning = true
err := s.cleanupConnection()
require.NoError(t, err)
assert.NotNil(t, s.connectClient, "connectClient is daemon-lifetime and must persist after cleanup")
assert.False(t, s.connectClient.ConnectionRunning(), "no run should be in flight after cleanup")
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
}
// TestCleanState_NotConnected validates that CleanState doesn't panic when no
// connection run is in flight.
func TestCleanState_NotConnected(t *testing.T) {
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
// when connectClient is nil.
func TestCleanState_NilConnectClient(t *testing.T) {
s := newTestServer()
s.profileManager = nil // will cause error if it tries to proceed
s.connectClient = nil
s.profileManager = nil // will cause error if it tries to proceed past the nil check
// Should not panic — the nil check should prevent calling Status() on nil
assert.NotPanics(t, func() {
_, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true})
})
}
// TestDeleteState_NotConnected validates that DeleteState doesn't panic when no
// connection run is in flight.
func TestDeleteState_NotConnected(t *testing.T) {
// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic
// when connectClient is nil.
func TestDeleteState_NilConnectClient(t *testing.T) {
s := newTestServer()
s.connectClient = nil
s.profileManager = nil
assert.NotPanics(t, func() {
@@ -129,6 +129,60 @@ func TestDeleteState_NotConnected(t *testing.T) {
})
}
// TestDownThenUp_StaleRunningChan documents the known state issue where
// clientRunningChan from a previous connection is already closed, causing
// waitForUp() to return immediately on reconnect.
func TestDownThenUp_StaleRunningChan(t *testing.T) {
s := newTestServer()
// Simulate state after a successful connection
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
close(s.clientRunningChan) // closed when engine started
s.clientGiveUpChan = make(chan struct{})
s.connectClient = newDummyConnectClient(context.Background())
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
// Simulate Down(): cleanupConnection sets connectClient = nil and
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
// remains independent of intent — derived from clientGiveUpChan.
s.mutex.Lock()
err := s.cleanupConnection()
s.mutex.Unlock()
require.NoError(t, err)
// After cleanup: connectClient is nil, clientRunning is false (intent
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
// (goroutine teardown is independent of the intent flag).
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
s.mutex.Unlock()
// waitForUp() returns immediately due to stale closed clientRunningChan
ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer ctxCancel()
waitDone := make(chan error, 1)
go func() {
_, err := s.waitForUp(ctx)
waitDone <- err
}()
select {
case err := <-waitDone:
assert.NoError(t, err, "waitForUp returns success on stale channel")
// But connectClient is still nil — this is the stale state issue
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success")
s.mutex.Unlock()
case <-time.After(1 * time.Second):
t.Fatal("waitForUp should have returned immediately due to stale closed channel")
}
}
// TestConnectClient_EngineNilOnFreshClient validates that a newly created
// ConnectClient has nil Engine (before Run is called).
func TestConnectClient_EngineNilOnFreshClient(t *testing.T) {

View File

@@ -31,6 +31,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
@@ -60,6 +61,65 @@ var (
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
func TestServer_Up(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir

View File

@@ -9,6 +9,7 @@ import (
"google.golang.org/grpc/status"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto"
@@ -37,7 +38,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro
// CleanState handles cleaning of states (performing cleanup operations)
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
if s.connectClient.ConnectionRunning() {
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
@@ -80,7 +81,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (
// DeleteState handles deletion of states without cleanup
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
if s.connectClient.ConnectionRunning() {
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}

View File

@@ -62,6 +62,10 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
}
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
if s.connectClient == nil {
return nil, nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, nil, fmt.Errorf("engine not initialized")

View File

@@ -1,56 +0,0 @@
# Build environments
Dockerfiles that pin the same toolchain CI uses, so a developer can
reproduce a CI build locally without installing platform SDKs on their
workstation. The version pins in each `Dockerfile` must stay in lockstep
with `.github/workflows/`.
## `android/`
Mirrors `.github/workflows/mobile-build-validation.yml` (`android_build`
job). Carries Go 1.25.5, Adopt JDK 11, Android cmdline-tools 8512546,
NDK 23.1.7779620 and gomobile pinned at the CI commit. Use it to
produce `netbird.aar` from `./client/android`:
```bash
docker build -t netbird/build-android docker/build-env/android
docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
gomobile bind \
-o netbird.aar \
-javapkg=io.netbird.gomobile \
-ldflags="-checklinkname=0 \
-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
-X github.com/netbirdio/netbird/version.version=local" \
./client/android
```
To build the full Android APK, bind-mount the `android-client` repo as
well and run its own `./gradlew assembleDebug` from inside the
container (the gradle wrapper ships with `android-client`).
## `windows-cross/`
Cross-compiles Windows binaries from Linux using `mingw-w64`. Lets you
verify that `GOOS=windows go build ./...` compiles cleanly without
needing a Windows VM. Cannot run Windows tests — the `golang-test-windows`
CI job executes on a native `windows-latest` runner with wintun.dll
and PsExec, neither of which lives under Linux containers.
```bash
docker build -t netbird/build-windows docker/build-env/windows-cross
docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
```
## What is NOT here
- **iOS / macOS**: cannot legally run macOS in Docker (Apple EULA),
and Xcode is not redistributable. The `ios_build` CI job uses a
`macos-latest` GitHub runner; locally you need a real Mac.
- **Native Windows tests**: see note above. The Linux+mingw image
builds, it does not execute Windows-host code paths
(registry, wintun, services, PsExec workflows).
When CI version pins change, update the corresponding `ARG` lines in
the Dockerfiles and the README's table of versions.

View File

@@ -1,86 +0,0 @@
# Android build environment.
#
# Mirrors the toolchain pinned by .github/workflows/mobile-build-validation.yml
# so a `gomobile bind` against ./client/android in this image produces the
# same netbird.aar that CI builds.
#
# Tooling versions (must stay in sync with the CI workflow):
# - Ubuntu 22.04 (matches the ubuntu-latest GitHub runner)
# - Go 1.25.5 (matches go.mod)
# - Adopt JDK 11 (matches actions/setup-java@v3 java-version: 11, distribution: adopt)
# - Android SDK cmdline-tools 8512546
# - Android NDK 23.1.7779620
# - gomobile commit v0.0.0-20251113184115-a159579294ab
#
# Usage (from the netbird repo root):
#
# docker build -t netbird/build-android docker/build-env/android
#
# # bind the netbird checkout in and run the same gomobile command CI runs
# docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
# gomobile bind \
# -o netbird.aar \
# -javapkg=io.netbird.gomobile \
# -ldflags="-checklinkname=0 \
# -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
# -X github.com/netbirdio/netbird/version.version=local" \
# ./client/android
#
# To build the full APK, mount the android-client repo too and run
# `./gradlew assembleDebug` from /android-client (this image carries
# gradle's prerequisites JDK + Android SDK but not the gradle wrapper —
# that ships with android-client).
FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive
# Versions — bump in lockstep with .github/workflows/mobile-build-validation.yml.
ARG GO_VERSION=1.25.5
ARG ANDROID_CMDLINE_TOOLS_VERSION=8512546
ARG ANDROID_NDK_VERSION=23.1.7779620
ARG GOMOBILE_VERSION=v0.0.0-20251113184115-a159579294ab
ENV ANDROID_HOME=/opt/android-sdk
ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${ANDROID_NDK_VERSION}
ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
ENV GOPATH=/go
ENV GOTOOLCHAIN=local
ENV CGO_ENABLED=0
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${ANDROID_HOME}/cmdline-tools/latest/bin:${ANDROID_HOME}/platform-tools:${JAVA_HOME}/bin:${PATH}
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
curl \
unzip \
git \
openjdk-11-jdk-headless \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Install Go (matches go.mod). actions/setup-go fetches the same tarball.
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
| tar -C /usr/local -xz \
&& go version
# Install Android SDK command-line tools, accept licenses, install NDK.
RUN mkdir -p "${ANDROID_HOME}/cmdline-tools" \
&& curl -fsSL -o /tmp/cmdline.zip \
"https://dl.google.com/android/repository/commandlinetools-linux-${ANDROID_CMDLINE_TOOLS_VERSION}_latest.zip" \
&& unzip -q /tmp/cmdline.zip -d "${ANDROID_HOME}/cmdline-tools" \
&& mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" \
&& rm /tmp/cmdline.zip \
&& yes | sdkmanager --licenses > /dev/null \
&& sdkmanager --install "ndk;${ANDROID_NDK_VERSION}" "platform-tools" > /dev/null
# Install gomobile at the same commit CI pins. Don't run `gomobile init` here:
# `init` resolves the NDK at runtime, do it on the first bind in the mounted
# workspace so the cache lands on the host volume.
RUN GOBIN=/usr/local/bin go install "golang.org/x/mobile/cmd/gomobile@${GOMOBILE_VERSION}" \
&& gomobile version
WORKDIR /src
# Default entrypoint is a plain shell so the image is composable: callers pass
# the full gomobile / gradle command they want to run.
CMD ["/bin/bash"]

View File

@@ -1,63 +0,0 @@
# Windows-cross build environment.
#
# Cross-compiles Windows .exe targets from a Linux container using
# mingw-w64. Mirrors the toolchain set used by
# .github/workflows/golang-test-windows.yml insofar as that is possible
# without a Windows kernel.
#
# IMPORTANT — what this image CAN do:
# - `GOOS=windows go build ./...` to validate that Windows builds compile
# - CGO Windows cross-compile via x86_64-w64-mingw32-gcc when CGO_ENABLED=1
# (matches CI's choco-installed mingw-w64)
#
# IMPORTANT — what this image CANNOT do:
# - Run Windows binaries (no Windows kernel under Docker on Linux).
# - Replicate the CI's `go test` runs which execute on a real
# windows-latest runner (wintun.dll, PsExec, registry, etc.).
# Use the CI for that or a native Windows VM.
#
# Usage (from the netbird repo root):
#
# docker build -t netbird/build-windows docker/build-env/windows-cross
#
# # Cross-compile a static client (.exe) from Linux:
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
# bash -c 'CGO_ENABLED=1 GOOS=windows GOARCH=amd64 \
# CC=x86_64-w64-mingw32-gcc CXX=x86_64-w64-mingw32-g++ \
# go build -o netbird.exe ./client'
#
# # Just validate that everything *compiles* on Windows (no CGO):
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
# bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
#
# Tooling versions (keep in sync with go.mod and any future explicit pin
# documented in golang-test-windows.yml):
# - Ubuntu 22.04
# - Go 1.25.5 (matches go.mod)
# - mingw-w64 (Ubuntu package — pin further if drift becomes a problem)
FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive
ARG GO_VERSION=1.25.5
ENV GOPATH=/go
ENV GOTOOLCHAIN=local
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${PATH}
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
curl \
git \
build-essential \
mingw-w64 \
&& rm -rf /var/lib/apt/lists/*
# Install Go (matches go.mod).
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
| tar -C /usr/local -xz \
&& go version
WORKDIR /src
CMD ["/bin/bash"]

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
return srv
}
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -33,6 +33,8 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
@@ -82,6 +84,9 @@ type ProxyServiceServer struct {
// Manager for users
usersManager users.Manager
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
idpManager idp.Manager
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
@@ -157,7 +162,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -166,6 +171,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
idpManager: idpManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
@@ -1702,22 +1708,7 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
@@ -1754,6 +1745,45 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}, nil
}
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
// user or peer ID, and peer name or user email.
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
// Resolve the principal: when the peer is linked to a user, the human is the
// principal so multiple peers owned by the same user share a single
// identity. Unlinked peers (machine agents) are their own principal keyed on
// peer.ID. displayIdentity is what upstream gateways tag spend with —
// user.Email when linked, peer.Name when not.
// If the peer isn't associated with a user, return the peer info directly.
if peer.UserID == "" {
return peer.ID, peer.Name
}
// Otherwise, if the peer is linked to a user, the user is the principal and
// if an IdP is available, we gather details on the user from it.
principalID := peer.UserID
displayIdentity := peer.Name
// Stored column first (cheap, but often empty for OIDC-provisioned users).
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
// IdP enrichment wins when available — the stored email column is a
// best-effort cache and is frequently empty for OIDC users. Enrichment
// failures must never fail the RPC; we simply keep the stored/peer identity.
if s.idpManager != nil {
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
displayIdentity = ud.Email
} else if uerr != nil {
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
}
}
return principalID, displayIdentity
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security

View File

@@ -3,14 +3,19 @@ package grpc
import (
"context"
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockReverseProxyManager struct {
@@ -137,6 +142,52 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
return user, nil, nil
}
// mockTunnelPeersManager implements only the two peers.Manager methods that
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
// panics if any unexpected method is invoked).
type mockTunnelPeersManager struct {
peers.Manager
peer *peer.Peer
peerErr error
groups []*types.Group
groupsErr error
}
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
return m.peer, m.peerErr
}
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
return m.peer, m.groups, m.groupsErr
}
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
// an IdP that knows nothing about the user.
type mockTunnelIdpManager struct {
idp.Manager
email string
hasData bool
err error
gotCalls int
gotMeta []idp.AppMetadata
}
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
m.gotCalls++
m.gotMeta = append(m.gotMeta, meta)
if m.err != nil {
return nil, m.err
}
if !m.hasData {
// This might not be a thing any of the actual IDP implementations do,
// i.e. return a nil value with no error, but it seems valuable to test
// that behavior here.
return nil, nil //nolint:nilnil
}
return &idp.UserData{ID: userID, Email: m.email}, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -354,6 +405,163 @@ func TestValidateUserGroupAccess(t *testing.T) {
}
}
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
// (IdP email -> stored User.Email -> peer.Name).
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
const (
domain = "app.example.com"
accountID = "account1"
peerID = "peer1"
peerName = "peer-display-name"
userID = "user1"
)
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
tests := []struct {
name string
peerUserID string
storedUsers map[string]*types.User
storedErr error
noIdP bool
idpEmail string
idpHasData bool
idpErr error
expectEmail string
expectUserID string
expectIdPHit bool
}{
{
name: "idp email wins over stored email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp returns empty email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "",
idpHasData: true,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp has no data",
peerUserID: userID,
storedUsers: storedUser,
idpHasData: false,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp errors",
peerUserID: userID,
storedUsers: storedUser,
idpErr: errors.New("idp unreachable"),
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when no idp manager",
peerUserID: userID,
storedUsers: storedUser,
noIdP: true,
expectEmail: "stored@example.com",
expectUserID: userID,
},
{
name: "idp email when stored email is empty",
peerUserID: userID,
storedUsers: storedUserNoEmail,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "idp email when stored user missing keeps peer.UserID as principal",
peerUserID: userID,
storedUsers: map[string]*types.User{},
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "unlinked peer uses peer name and never consults idp",
peerUserID: "",
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: peerName,
expectUserID: peerID,
expectIdPHit: false,
},
{
name: "linked peer with empty stored email and no idp falls back to peer name",
peerUserID: userID,
storedUsers: storedUserNoEmail,
noIdP: true,
expectEmail: peerName,
expectUserID: userID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := &service.Service{Domain: domain, AccountID: accountID}
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
},
peersManager: &mockTunnelPeersManager{
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
},
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
}
var idpMock *mockTunnelIdpManager
if !tt.noIdP {
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
server.idpManager = idpMock
}
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
Domain: domain,
TunnelIp: "100.64.0.1",
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.True(t, resp.GetValid(), "expected access granted")
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
assert.Equal(t, tt.expectUserID, resp.GetUserId())
if idpMock != nil {
if tt.expectIdPHit {
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
require.Len(t, idpMock.gotMeta, 1)
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
} else {
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
}
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)

View File

@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err

View File

@@ -217,6 +217,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager,
nil,
nil,
nil,
)
proxyService.SetServiceManager(&testServiceManager{store: testStore})

View File

@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {

View File

@@ -982,8 +982,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var peer *nbpeer.Peer
var updated, versionChanged, ipv6CapabilityChanged bool
var err error
var postureChecks []*posture.Checks
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -1011,13 +1009,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return status.NewPeerLoginExpiredError()
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
@@ -1025,11 +1018,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err
}
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
}
return nil
})
@@ -1037,6 +1025,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
@@ -1047,9 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1124,7 +1117,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
var peer *nbpeer.Peer
var shouldStorePeer bool
var shouldStorePeer, shouldUpdatePeers bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1151,14 +1144,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
shouldUpdatePeers = true
}
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
@@ -1180,7 +1169,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
peer.UpdateMetaIfNew(ctx, login.Meta)
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, false, err
}
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, false, err
}
@@ -1190,7 +1187,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
if isStatusChanged || shouldStorePeer {
if shouldUpdatePeers {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1286,12 +1283,22 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, nil, false, nil
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
if err != nil {
return nil, nil, false, err
}
@@ -1299,32 +1306,16 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, postureChecks, enableSSH, nil
}
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
for _, peerID := range peerGroupIDs {
groupIDsMap[peerID] = struct{}{}
}
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return false, err
}
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
for _, g := range peerGroups {
peerGroupIDs[g.ID] = struct{}{}
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
}
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
if len(policies) == 0 {
return nil, nil
}
@@ -1336,11 +1327,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
continue
}
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
if err != nil {
return nil, err
}
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
}
@@ -1353,29 +1340,19 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
}
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil {
return nil, err
}
for _, sourceGroup := range rule.Sources {
group, ok := sourceGroups[sourceGroup]
if !ok {
return nil, fmt.Errorf("failed to check peer in policy source group")
}
if slices.Contains(group.Peers, peerID) {
return policy.SourcePostureChecks, nil
if slices.Contains(peerGroupIDs, sourceGroup) {
return policy.SourcePostureChecks
}
}
}
return nil, nil
return nil
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO

View File

@@ -1,12 +1,16 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -162,49 +166,7 @@ type PeerSystemMeta struct { //nolint:revive
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
})
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
})
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
sort.Slice(p.Files, func(i, j int) bool {
return p.Files[i].Path < p.Files[j].Path
})
sort.Slice(other.Files, func(i, j int) bool {
return other.Files[i].Path < other.Files[j].Path
})
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.KernelVersion == other.KernelVersion &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.OSVersion == other.OSVersion &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion &&
p.SystemSerialNumber == other.SystemSerialNumber &&
p.SystemProductName == other.SystemProductName &&
p.SystemManufacturer == other.SystemManufacturer &&
p.Environment.Cloud == other.Environment.Cloud &&
p.Environment.Platform == other.Environment.Platform &&
p.Flags.isEqual(other.Flags) &&
capabilitiesEqual(p.Capabilities, other.Capabilities)
return len(metaDiff(p, other)) == 0
}
func (p PeerSystemMeta) isEmpty() bool {
@@ -296,7 +258,7 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
if meta.isEmpty() {
return updated, versionChanged
}
@@ -308,14 +270,121 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged boo
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return updated, versionChanged
oldVersion := p.Meta.WtVersion
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
p.Meta = meta
updated = true
}
p.Meta = meta
updated = true
versionInfo := ""
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return updated, versionChanged
}
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
// source of truth for meta comparison: isEqual reports equality as an empty
// diff, so the log line can never disagree with the change decision. Slices are
// cloned before sorting, so callers' meta is not mutated.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
var diff []string
add := func(field string, oldVal, newVal any) {
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return diff
}
// sameMultiset reports whether two slices contain the same elements with the
// same multiplicity, ignoring order. The element type is the comparison key, so
// every field participates in equality.
func sameMultiset[T comparable](a, b []T) bool {
if len(a) != len(b) {
return false
}
counts := make(map[T]int, len(a))
for _, v := range a {
counts[v]++
}
for _, v := range b {
counts[v]--
if counts[v] == 0 {
delete(counts, v)
}
}
return len(counts) == 0
}
// GetLastLogin returns the last login time of the peer.
func (p *Peer) GetLastLogin() time.Time {
if p.LastLogin != nil {

View File

@@ -0,0 +1,113 @@
package peer
import (
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
// map 1:1 to a single diff entry. Today the only such field is Environment, which
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
// entry beyond its single struct field. If you teach metaDiff to explode another
// field into N entries, bump this by N-1; if you collapse a field, lower it.
const metaDiffExtraEntries = 1
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
// values and diffs it against the zero value, then asserts metaDiff emits exactly
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
//
// The expected count is derived from the struct via reflection, so adding a field
// to PeerSystemMeta raises the expectation automatically — but the actual diff
// only grows if metaDiff was taught to compare the new field. A mismatch means
// someone changed the struct without updating metaDiff (or this test's
// extra-entry accounting), which is exactly what we want to catch.
func TestMetaDiff_CoversAllFields(t *testing.T) {
var full PeerSystemMeta
exported := populateAll(t, reflect.ValueOf(&full).Elem())
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
diff := metaDiff(PeerSystemMeta{}, full)
require.Len(t, diff, exported+metaDiffExtraEntries,
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
"diff was: %v", diff)
require.False(t, full.isEqual(PeerSystemMeta{}),
"isEqual must report a fully-populated meta as different from the zero value")
}
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
// compare would not change the diff count. This flips each Flags field on its own
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
// fails here.
func TestFlags_isEqualChecksEveryField(t *testing.T) {
typ := reflect.TypeOf(Flags{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
require.Equal(t, reflect.Bool, f.Type.Kind(),
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
var a, b Flags
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
}
}
// populateAll sets every exported field of the struct to a deterministic non-zero
// value, recursing into nested structs and the element type of struct slices so
// that each leaf differs from zero. It returns the number of exported fields on
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
// settable exported fields and is comparable with ==).
func populateAll(t *testing.T, v reflect.Value) int {
t.Helper()
typ := v.Type()
exported := 0
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath != "" { // unexported
continue
}
exported++
setNonZero(t, v.Field(i))
}
return exported
}
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
// recursing into nested structs and populating one element of slice fields.
func setNonZero(t *testing.T, field reflect.Value) {
t.Helper()
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
return
}
switch field.Kind() {
case reflect.String:
field.SetString("non-zero")
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Float32, reflect.Float64:
field.SetFloat(7)
case reflect.Struct:
populateAll(t, field)
case reflect.Slice:
s := reflect.MakeSlice(field.Type(), 1, 1)
setNonZero(t, s.Index(0))
field.Set(s)
default:
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
}
}

View File

@@ -125,6 +125,7 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
oidcConfig,
nil,
usersManager,
nil,
realProxyManager,
nil,
)

View File

@@ -140,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
oidcConfig,
nil,
usersManager,
nil,
proxyManager,
nil,
)

View File

@@ -21,7 +21,8 @@ AWK_FIRST_FIELD='{print $1}'
fetch_all_tags() {
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \
sort -u -V
return 0

View File

@@ -32,7 +32,8 @@ fetch_current_ports_version() {
fetch_all_tags() {
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \
sort -u -V
return 0

View File

@@ -140,12 +140,7 @@ func newRotatedOutput(logPath string) io.Writer {
func setGRPCLibLogger(logger *log.Logger) {
logOut := logger.Writer()
if os.Getenv("GRPC_GO_LOG_SEVERITY_LEVEL") != "info" {
// Discard grpc info AND warning logs by default — the warning stream is
// dominated by benign connection-retry noise ("addrConn.createTransport
// failed", "transport is closing") that surfaces e.g. when the CLI dials
// a daemon that is still starting or already gone. Errors are kept. Set
// GRPC_GO_LOG_SEVERITY_LEVEL=info to get the full verbose grpc logging.
grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, io.Discard, logOut))
grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, logOut, logOut))
return
}