Compare commits

..

15 Commits

Author SHA1 Message Date
Theodor S. Midtlien
a04b36e516 Add tests for rotated files in bundle 2026-05-21 19:29:41 +02:00
Theodor S. Midtlien
ece84cac57 Add logging tests 2026-05-21 19:05:16 +02:00
Theodor S. Midtlien
f9311fa7ab Fix permissions 2026-05-21 17:24:29 +02:00
Theodor S. Midtlien
b7a4bd9d70 Add timberjack 2026-05-21 15:16:11 +02:00
Theodor S. Midtlien
09df098f13 Add all uncompressed logs and fix service status 2026-05-21 12:09:27 +02:00
Theodor S. Midtlien
a02767eacb Fix daemon and cli version in bundle and status 2026-05-20 18:09:08 +02:00
Theodor S. Midtlien
c8746851e3 Fix log file open and bundle uncompressed files 2026-05-20 15:22:20 +02:00
Theodor S. Midtlien
31c30e6d87 Remove directory creation 2026-05-20 13:50:05 +02:00
Theodor S. Midtlien
f375b86a79 Refactor log disable check 2026-05-20 11:36:53 +02:00
Theodor S. Midtlien
5a9a3e1717 Add uncompressed logrotate files to bundle 2026-05-20 11:21:35 +02:00
Theodor S. Midtlien
911f1e5fe9 Add env flag and refactor 2026-05-20 10:45:12 +02:00
Theodor S. Midtlien
b26dac7a46 Fix comment 2026-05-19 21:53:43 +02:00
Theodor S. Midtlien
18c5460831 Ensure that log directory exists 2026-05-19 19:37:57 +02:00
Theodor S. Midtlien
9d74bada48 Add logrotate conflict detection 2026-05-19 18:57:44 +02:00
Theodor S. Midtlien
e80cd79355 Add daemon version to bundle 2026-05-19 11:17:15 +02:00
44 changed files with 681 additions and 1517 deletions

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/version"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -101,6 +101,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
Anonymize: anonymizeFlag,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
CliVersion: version.NetbirdVersion(),
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
@@ -109,41 +110,16 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
return emitDebugBundle(cmd, resp.GetPath(), resp.GetUploadedKey())
}
func emitDebugBundle(cmd *cobra.Command, path, uploadedKey string) error {
if !jsonFlag && !yamlFlag {
cmd.Printf("Local file:\n%s\n", path)
if uploadBundleFlag {
cmd.Printf("Upload file key:\n%s\n", uploadedKey)
}
return nil
}
out := &nbstatus.DebugBundleOutput{Path: path}
if uploadBundleFlag {
out.UploadedKey = uploadedKey
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
if jsonFlag {
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
return nil
}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
return nil
}
@@ -324,6 +300,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
Anonymize: anonymizeFlag,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
CliVersion: version.NetbirdVersion(),
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
@@ -458,6 +435,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
SyncResponse: syncResponse,
LogPath: logFilePath,
CPUProfile: nil,
DaemonVersion: version.NetbirdVersion(), // acting as daemon
},
debug.BundleConfig{
IncludeSystemInfo: true,
@@ -477,9 +455,6 @@ func init() {
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
debugBundleCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
debugBundleCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
debugBundleCmd.MarkFlagsMutuallyExclusive("json", "yaml")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")

View File

@@ -10,7 +10,6 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
)
var downCmd = &cobra.Command{
@@ -45,29 +44,7 @@ var downCmd = &cobra.Command{
return err
}
out := &nbstatus.DownOutput{Status: "Disconnected"}
switch {
case jsonFlag:
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
case yamlFlag:
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
default:
cmd.Println(out.Status)
}
cmd.Println("Disconnected")
return nil
},
}
func init() {
downCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
downCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
downCmd.MarkFlagsMutuallyExclusive("json", "yaml")
}

View File

@@ -8,7 +8,6 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
)
var forwardingRulesCmd = &cobra.Command{
@@ -26,12 +25,6 @@ var forwardingRulesListCmd = &cobra.Command{
RunE: listForwardingRules,
}
func init() {
forwardingRulesListCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
forwardingRulesListCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
forwardingRulesListCmd.MarkFlagsMutuallyExclusive("json", "yaml")
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
@@ -45,23 +38,19 @@ func listForwardingRules(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
rules := resp.GetRules()
sortForwardingRules(rules)
if jsonFlag || yamlFlag {
return emitForwardingList(cmd, rules)
}
if len(rules) == 0 {
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, rules)
printForwardingRules(cmd, resp.GetRules())
return nil
}
func sortForwardingRules(rules []*proto.ForwardingRule) {
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
@@ -69,45 +58,9 @@ func sortForwardingRules(rules []*proto.ForwardingRule) {
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
}
func emitForwardingList(cmd *cobra.Command, rules []*proto.ForwardingRule) error {
out := &nbstatus.ForwardingListOutput{Rules: make([]nbstatus.ForwardingRuleOutput, 0, len(rules))}
for _, rule := range rules {
row := nbstatus.ForwardingRuleOutput{
TranslatedAddress: rule.GetTranslatedAddress(),
TranslatedHostname: rule.GetTranslatedHostname(),
Protocol: rule.GetProtocol(),
}
if s, ok := portToStringOpt(rule.GetDestinationPort()); ok {
row.DestinationPort = &s
}
if s, ok := portToStringOpt(rule.GetTranslatedPort()); ok {
row.TranslatedPort = &s
}
out.Rules = append(out.Rules, row)
}
if jsonFlag {
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
return nil
}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
var lastIP string
for _, rule := range rules {
@@ -134,22 +87,12 @@ func getFirstPort(portInfo *proto.PortInfo) int {
}
func portToString(translatedPort *proto.PortInfo) string {
if s, ok := portToStringOpt(translatedPort); ok {
return s
}
return "No port specified"
}
// portToStringOpt returns the formatted port string and whether port info was
// actually present. Used by the structured (json/yaml) output so the absent
// case becomes a missing field instead of a sentinel string.
func portToStringOpt(p *proto.PortInfo) (string, bool) {
switch v := p.GetPortSelection().(type) {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port), true
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd()), true
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "", false
return "No port specified"
}
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
)
@@ -28,10 +27,6 @@ func init() {
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
loginCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
loginCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
loginCmd.MarkFlagsMutuallyExclusive("json", "yaml")
}
var loginCmd = &cobra.Command{
@@ -78,34 +73,10 @@ var loginCmd = &cobra.Command{
return fmt.Errorf("daemon login failed: %v", err)
}
return emitLoginOutput(cmd, activeProf.Name)
},
}
// emitLoginOutput writes the result of a successful login in the format
// requested by the user (json, yaml, or human-readable text fallback).
func emitLoginOutput(cmd *cobra.Command, profileName string) error {
out := &nbstatus.LoginOutput{
Status: "logged_in",
ProfileName: profileName,
}
switch {
case jsonFlag:
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
case yamlFlag:
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
default:
cmd.Println("Logging successfully")
}
return nil
return nil
},
}
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
@@ -282,7 +253,8 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
return emitLoginOutput(cmd, activeProf.Name)
cmd.Println("Logging successfully")
return nil
}
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
@@ -365,11 +337,6 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
}
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
if jsonFlag || yamlFlag {
emitSSOEvent(cmd, verificationURIComplete, userCode)
return
}
var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -399,33 +366,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
}
}
// emitSSOEvent writes the verification URL/code as a structured event for
// callers using --json or --yaml. The browser is intentionally not opened in
// this mode since automation contexts (CI, scripts) typically run headless.
func emitSSOEvent(cmd *cobra.Command, verificationURIComplete, userCode string) {
event := &nbstatus.SSOEvent{
Event: "sso_required",
VerificationURIComplete: verificationURIComplete,
UserCode: userCode,
}
if jsonFlag {
s, err := event.JSON()
if err != nil {
log.Errorf("marshal sso event: %v", err)
return
}
cmd.Println(s)
return
}
s, err := event.YAML()
if err != nil {
log.Errorf("marshal sso event: %v", err)
return
}
cmd.Print(s)
cmd.Println("---")
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -9,7 +9,6 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
)
var logoutCmd = &cobra.Command{
@@ -50,33 +49,11 @@ var logoutCmd = &cobra.Command{
return fmt.Errorf("deregister: %v", err)
}
out := &nbstatus.DeregisterOutput{
Status: "deregistered",
ProfileName: profileName,
}
switch {
case jsonFlag:
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
case yamlFlag:
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
default:
cmd.Println("Deregistered successfully")
}
cmd.Println("Deregistered successfully")
return nil
},
}
func init() {
logoutCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
logoutCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
logoutCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
logoutCmd.MarkFlagsMutuallyExclusive("json", "yaml")
}

View File

@@ -8,7 +8,6 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
)
var appendFlag bool
@@ -49,12 +48,6 @@ var routesDeselectCmd = &cobra.Command{
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
for _, c := range []*cobra.Command{routesListCmd, routesSelectCmd, routesDeselectCmd} {
c.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
c.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
c.MarkFlagsMutuallyExclusive("json", "yaml")
}
}
func networksList(cmd *cobra.Command, _ []string) error {
@@ -70,10 +63,6 @@ func networksList(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if jsonFlag || yamlFlag {
return emitNetworksList(cmd, resp)
}
if len(resp.Routes) == 0 {
cmd.Println("No networks available.")
return nil
@@ -84,40 +73,6 @@ func networksList(cmd *cobra.Command, _ []string) error {
return nil
}
func emitNetworksList(cmd *cobra.Command, resp *proto.ListNetworksResponse) error {
out := &nbstatus.NetworksListOutput{Networks: make([]nbstatus.NetworkOutput, 0, len(resp.GetRoutes()))}
for _, route := range resp.GetRoutes() {
row := nbstatus.NetworkOutput{
ID: route.GetID(),
Range: route.GetRange(),
Domains: route.GetDomains(),
Selected: route.GetSelected(),
}
if resolved := route.GetResolvedIPs(); len(resolved) > 0 {
row.ResolvedIPs = make(map[string][]string, len(resolved))
for d, ipList := range resolved {
row.ResolvedIPs[d] = ipList.GetIps()
}
}
out.Networks = append(out.Networks, row)
}
if jsonFlag {
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
return nil
}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
return nil
}
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
cmd.Println("Available Networks:")
for _, route := range resp.Routes {
@@ -187,9 +142,10 @@ func networksSelect(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
}
return emitNetworksMutation(cmd, "selected", req, "Networks selected successfully.")
}
cmd.Println("Networks selected successfully.")
return nil
}
func networksDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
@@ -211,35 +167,7 @@ func networksDeselect(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
}
return emitNetworksMutation(cmd, "deselected", req, "Networks deselected successfully.")
}
cmd.Println("Networks deselected successfully.")
func emitNetworksMutation(cmd *cobra.Command, action string, req *proto.SelectNetworksRequest, textFallback string) error {
if !jsonFlag && !yamlFlag {
cmd.Println(textFallback)
return nil
}
out := &nbstatus.NetworksMutationOutput{
Status: action,
All: req.GetAll(),
}
if !req.GetAll() {
out.Networks = req.GetNetworkIDs()
}
if jsonFlag {
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
return nil
}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
return nil
}

View File

@@ -11,18 +11,9 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
)
func init() {
for _, c := range []*cobra.Command{profileListCmd, profileAddCmd, profileRemoveCmd, profileSelectCmd} {
c.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
c.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
c.MarkFlagsMutuallyExclusive("json", "yaml")
}
}
var profileCmd = &cobra.Command{
Use: "profile",
Short: "Manage NetBird client profiles",
@@ -99,10 +90,6 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
return err
}
if jsonFlag || yamlFlag {
return emitProfileList(cmd, profiles.GetProfiles())
}
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
@@ -117,58 +104,6 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
return nil
}
func emitProfileMutation(cmd *cobra.Command, action, profile, textFallback string) error {
if !jsonFlag && !yamlFlag {
cmd.Println(textFallback)
return nil
}
out := &nbstatus.ProfileMutationOutput{
Status: action,
ProfileName: profile,
}
if jsonFlag {
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
return nil
}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
return nil
}
func emitProfileList(cmd *cobra.Command, profiles []*proto.Profile) error {
out := &nbstatus.ProfileListOutput{Profiles: make([]nbstatus.ProfileOutput, 0, len(profiles))}
for _, p := range profiles {
out.Profiles = append(out.Profiles, nbstatus.ProfileOutput{
Name: p.GetName(),
Active: p.GetIsActive(),
})
}
if jsonFlag {
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
return nil
}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
return nil
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -197,7 +132,8 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
return err
}
return emitProfileMutation(cmd, "added", profileName, "Profile added successfully: "+profileName)
cmd.Println("Profile added successfully:", profileName)
return nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
@@ -228,7 +164,8 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
return err
}
return emitProfileMutation(cmd, "removed", profileName, "Profile removed successfully: "+profileName)
cmd.Println("Profile removed successfully:", profileName)
return nil
}
func selectProfileFunc(cmd *cobra.Command, args []string) error {
@@ -294,5 +231,6 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
}
return emitProfileMutation(cmd, "selected", profileName, "Profile switched successfully to: "+profileName)
cmd.Println("Profile switched successfully to:", profileName)
return nil
}

View File

@@ -102,7 +102,7 @@ func (p *program) Stop(srv service.Service) error {
}
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc, consoleLog bool) (service.Service, error) {
// rootCmd env vars are already applied by PersistentPreRunE.
SetFlagsFromEnvVars(serviceCmd)
@@ -112,8 +112,14 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
return nil, err
}
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
if consoleLog {
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
} else {
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
}
cfg, err := newSVCConfig()
@@ -138,7 +144,7 @@ var runCmd = &cobra.Command{
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -152,7 +158,7 @@ var startCmd = &cobra.Command{
Short: "starts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -170,7 +176,7 @@ var stopCmd = &cobra.Command{
Short: "stops NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -188,7 +194,7 @@ var restartCmd = &cobra.Command{
Short: "restarts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -206,7 +212,7 @@ var svcStatusCmd = &cobra.Command{
Short: "shows NetBird service status",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, true)
if err != nil {
return err
}

View File

@@ -8,7 +8,6 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
)
var (
@@ -76,10 +75,6 @@ func init() {
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
stateListCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
stateListCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
stateListCmd.MarkFlagsMutuallyExclusive("json", "yaml")
}
func stateList(cmd *cobra.Command, _ []string) error {
@@ -99,10 +94,6 @@ func stateList(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
}
if jsonFlag || yamlFlag {
return emitStateList(cmd, resp.GetStates())
}
cmd.Printf("\nStored states:\n\n")
for _, state := range resp.States {
cmd.Printf("- %s\n", state.Name)
@@ -111,28 +102,6 @@ func stateList(cmd *cobra.Command, _ []string) error {
return nil
}
func emitStateList(cmd *cobra.Command, states []*proto.State) error {
out := &nbstatus.StateListOutput{States: make([]nbstatus.StateOutput, 0, len(states))}
for _, s := range states {
out.States = append(out.States, nbstatus.StateOutput{Name: s.GetName()})
}
if jsonFlag {
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
return nil
}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
return nil
}
func stateClean(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {

View File

@@ -22,7 +22,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
@@ -89,9 +88,6 @@ func init() {
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
upCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
upCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
upCmd.MarkFlagsMutuallyExclusive("json", "yaml")
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -100,10 +96,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout())
if (jsonFlag || yamlFlag) && foregroundMode {
return fmt.Errorf("--json/--yaml is not supported with --foreground-mode; use daemon mode")
}
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
@@ -253,10 +245,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
if status.Status == string(internal.StatusConnected) {
if !profileSwitched {
return emitUpOutput(cmd, &nbstatus.UpOutput{
Status: "already_connected",
ProfileName: activeProf.Name,
}, "Already connected")
cmd.Println("Already connected")
return nil
}
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
@@ -283,31 +273,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
if err := doDaemonUp(ctx, cmd, client, pm, activeProf, customDNSAddressConverted, username.Username); err != nil {
return fmt.Errorf("daemon up failed: %v", err)
}
return emitUpOutput(cmd, &nbstatus.UpOutput{
Status: "connected",
ProfileName: activeProf.Name,
}, "Connected")
}
// emitUpOutput writes the result of an up command in the format requested by
// the user (json, yaml, or human-readable text fallback).
func emitUpOutput(cmd *cobra.Command, out *nbstatus.UpOutput, textFallback string) error {
switch {
case jsonFlag:
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
case yamlFlag:
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
default:
cmd.Println(textFallback)
}
cmd.Println("Connected")
return nil
}

View File

@@ -3,7 +3,6 @@ package cmd
import (
"github.com/spf13/cobra"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/version"
)
@@ -11,35 +10,9 @@ var (
versionCmd = &cobra.Command{
Use: "version",
Short: "Print the NetBird's client application version",
RunE: func(cmd *cobra.Command, args []string) error {
Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout())
v := version.NetbirdVersion()
switch {
case jsonFlag:
out := &nbstatus.VersionOutput{Version: v}
s, err := out.JSON()
if err != nil {
return err
}
cmd.Println(s)
case yamlFlag:
out := &nbstatus.VersionOutput{Version: v}
s, err := out.YAML()
if err != nil {
return err
}
cmd.Print(s)
default:
cmd.Println(v)
}
return nil
cmd.Println(version.NetbirdVersion())
},
}
)
func init() {
versionCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display command result in json format")
versionCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display command result in yaml format")
versionCmd.MarkFlagsMutuallyExclusive("json", "yaml")
}

View File

@@ -254,6 +254,8 @@ type BundleGenerator struct {
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
daemonVersion string
cliVersion string
anonymize bool
includeSystemInfo bool
@@ -278,6 +280,8 @@ type GeneratorDependencies struct {
CapturePath string
RefreshStatus func()
ClientMetrics MetricsExporter
DaemonVersion string
CliVersion string
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -299,6 +303,8 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
daemonVersion: deps.DaemonVersion,
cliVersion: deps.CliVersion,
anonymize: cfg.Anonymize,
includeSystemInfo: cfg.IncludeSystemInfo,
@@ -459,9 +465,11 @@ func (g *BundleGenerator) addStatus() error {
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
Anonymize: g.anonymize,
ProfileName: profName,
Anonymize: g.anonymize,
ProfileName: profName,
DaemonVersion: g.daemonVersion,
})
overview.CliVersion = g.cliVersion
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
@@ -1039,7 +1047,8 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
return
}
pattern := filepath.Join(logDir, "client-*.log.gz")
// This regex will match both logs rotated by us and logrotate on linux
pattern := filepath.Join(logDir, "client*.log.*")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
@@ -1072,7 +1081,12 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
for i := 0; i < maxFiles; i++ {
name := filepath.Base(files[i])
if err := g.addSingleLogFileGz(files[i], name); err != nil {
if strings.HasSuffix(name, ".gz") {
err = g.addSingleLogFileGz(files[i], name)
} else {
err = g.addSingleLogfile(files[i], name)
}
if err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}

View File

@@ -0,0 +1,103 @@
package debug
import (
"archive/zip"
"bytes"
"compress/gzip"
"io"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestAddRotatedLogFiles_PicksUpAllVariants asserts that the rotated-log
// glob picks up logs rotated by timberjack (gzipped) and by logrotate (plain
// and gzipped), and skips unrelated files.
func TestAddRotatedLogFiles_PicksUpAllVariants(t *testing.T) {
dir := t.TempDir()
writeFile(t, filepath.Join(dir, "client.log"), "active log\n")
writeFile(t, filepath.Join(dir, "other.log"), "unrelated\n")
timberjackRotated := "client-2026-05-21T10-30-45.000.log.gz"
writeGzFile(t, filepath.Join(dir, timberjackRotated), "timberjack rotated content\n")
logrotatePlain := "client.log.1"
writeFile(t, filepath.Join(dir, logrotatePlain), "logrotate plain content\n")
logrotateGz := "client.log.2.gz"
writeGzFile(t, filepath.Join(dir, logrotateGz), "logrotate gz content\n")
names := runAddRotatedLogFiles(t, dir, 10)
require.Contains(t, names, timberjackRotated, "timberjack rotated file should be in bundle")
require.Contains(t, names, logrotatePlain, "logrotate plain rotated file should be in bundle")
require.Contains(t, names, logrotateGz, "logrotate gzipped rotated file should be in bundle")
require.NotContains(t, names, "client.log", "active log should not be added by addRotatedLogFiles")
require.NotContains(t, names, "other.log", "unrelated files should not be in bundle")
}
// TestAddRotatedLogFiles_RespectsLogFileCount asserts that only the newest
// logFileCount rotated files are bundled, ordered by mtime.
func TestAddRotatedLogFiles_RespectsLogFileCount(t *testing.T) {
dir := t.TempDir()
oldest := filepath.Join(dir, "client.log.3")
middle := filepath.Join(dir, "client.log.2")
newest := filepath.Join(dir, "client.log.1")
writeFile(t, oldest, "old\n")
writeFile(t, middle, "mid\n")
writeFile(t, newest, "new\n")
now := time.Now()
require.NoError(t, os.Chtimes(oldest, now.Add(-2*time.Hour), now.Add(-2*time.Hour)))
require.NoError(t, os.Chtimes(middle, now.Add(-1*time.Hour), now.Add(-1*time.Hour)))
require.NoError(t, os.Chtimes(newest, now, now))
names := runAddRotatedLogFiles(t, dir, 2)
require.Contains(t, names, "client.log.1")
require.Contains(t, names, "client.log.2")
require.NotContains(t, names, "client.log.3", "oldest file should be dropped when logFileCount=2")
}
// runAddRotatedLogFiles calls addRotatedLogFiles against a fresh in-memory
// zip writer and returns the set of entry names that ended up in the archive.
func runAddRotatedLogFiles(t *testing.T, dir string, logFileCount uint32) map[string]struct{} {
t.Helper()
var buf bytes.Buffer
g := &BundleGenerator{
archive: zip.NewWriter(&buf),
logFileCount: logFileCount,
}
g.addRotatedLogFiles(dir)
require.NoError(t, g.archive.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
names := make(map[string]struct{}, len(zr.File))
for _, f := range zr.File {
names[f.Name] = struct{}{}
}
return names
}
func writeFile(t *testing.T, path, content string) {
t.Helper()
require.NoError(t, os.WriteFile(path, []byte(content), 0o644))
}
func writeGzFile(t *testing.T, path, content string) {
t.Helper()
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
_, err := io.WriteString(gw, content)
require.NoError(t, err)
require.NoError(t, gw.Close())
require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o644))
}

View File

@@ -72,6 +72,7 @@ import (
sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/capture"
"github.com/netbirdio/netbird/version"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -1141,6 +1142,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
DaemonVersion: version.NetbirdVersion(),
RefreshStatus: func() {
e.RunHealthProbes(true)
},

View File

@@ -2709,6 +2709,7 @@ type DebugBundleRequest struct {
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
CliVersion string `protobuf:"bytes,6,opt,name=cliVersion,proto3" json:"cliVersion,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -2771,6 +2772,13 @@ func (x *DebugBundleRequest) GetLogFileCount() uint32 {
return 0
}
func (x *DebugBundleRequest) GetCliVersion() string {
if x != nil {
return x.CliVersion
}
return ""
}
type DebugBundleResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
@@ -6475,14 +6483,17 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xb4\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
"\n" +
"systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" +
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\x12\x1e\n" +
"\n" +
"cliVersion\x18\x06 \x01(\tR\n" +
"cliVersion\"}\n" +
"\x13DebugBundleResponse\x12\x12\n" +
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +

View File

@@ -471,6 +471,7 @@ message DebugBundleRequest {
bool systemInfo = 3;
string uploadURL = 4;
uint32 logFileCount = 5;
string cliVersion = 6;
}
message DebugBundleResponse {

View File

@@ -1,17 +1,16 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
if ! which realpath >/dev/null 2>&1; then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
script_path=$(dirname "$(realpath "$0")")
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.6.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/version"
)
// DebugBundle creates a debug bundle and returns the location.
@@ -67,6 +68,8 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
CapturePath: capturePath,
RefreshStatus: refreshStatus,
ClientMetrics: clientMetrics,
DaemonVersion: version.NetbirdVersion(),
CliVersion: req.CliVersion,
},
debug.BundleConfig{
Anonymize: req.GetAnonymize(),

View File

@@ -384,350 +384,6 @@ func (o *OutputOverview) YAML() (string, error) {
return string(yamlBytes), nil
}
// DownOutput is the structured result of a `netbird down` command.
type DownOutput struct {
Status string `json:"status" yaml:"status"`
}
// UpOutput is the final structured result of a `netbird up` command.
type UpOutput struct {
Status string `json:"status" yaml:"status"` // "connected" | "already_connected"
ProfileName string `json:"profileName,omitempty" yaml:"profileName,omitempty"`
}
// JSON returns the UpOutput as a JSON string.
func (o *UpOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the UpOutput as a YAML string.
func (o *UpOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// SSOEvent is emitted before the final UpOutput when interactive SSO is required.
// It carries the verification URL and user code so callers can surface them
// (e.g. to a Slack channel) without parsing free-form text.
type SSOEvent struct {
Event string `json:"event" yaml:"event"` // "sso_required"
VerificationURIComplete string `json:"verificationUriComplete" yaml:"verificationUriComplete"`
UserCode string `json:"userCode,omitempty" yaml:"userCode,omitempty"`
}
// JSON returns the SSOEvent as a JSON string.
func (e *SSOEvent) JSON() (string, error) {
jsonBytes, err := json.Marshal(e)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the SSOEvent as a YAML string.
func (e *SSOEvent) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(e)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// NetworkOutput is one row of a networks-list response.
type NetworkOutput struct {
ID string `json:"id" yaml:"id"`
Range string `json:"range,omitempty" yaml:"range,omitempty"`
Domains []string `json:"domains,omitempty" yaml:"domains,omitempty"`
ResolvedIPs map[string][]string `json:"resolvedIps,omitempty" yaml:"resolvedIps,omitempty"`
Selected bool `json:"selected" yaml:"selected"`
}
// NetworksListOutput is the structured result of `netbird networks list`.
type NetworksListOutput struct {
Networks []NetworkOutput `json:"networks" yaml:"networks"`
}
// JSON returns the NetworksListOutput as a JSON string.
func (o *NetworksListOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the NetworksListOutput as a YAML string.
func (o *NetworksListOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// ForwardingRuleOutput is one row of a forwarding-list response.
// DestinationPort and TranslatedPort are pointers so that absent port info
// (oneof not set in the protobuf) surfaces as a missing field rather than a
// human-readable sentinel string.
type ForwardingRuleOutput struct {
TranslatedAddress string `json:"translatedAddress" yaml:"translatedAddress"`
TranslatedHostname string `json:"translatedHostname,omitempty" yaml:"translatedHostname,omitempty"`
Protocol string `json:"protocol" yaml:"protocol"`
DestinationPort *string `json:"destinationPort,omitempty" yaml:"destinationPort,omitempty"`
TranslatedPort *string `json:"translatedPort,omitempty" yaml:"translatedPort,omitempty"`
}
// ForwardingListOutput is the structured result of `netbird forwarding list`.
type ForwardingListOutput struct {
Rules []ForwardingRuleOutput `json:"rules" yaml:"rules"`
}
// JSON returns the ForwardingListOutput as a JSON string.
func (o *ForwardingListOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the ForwardingListOutput as a YAML string.
func (o *ForwardingListOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// ProfileOutput is one row of a profile-list response.
type ProfileOutput struct {
Name string `json:"name" yaml:"name"`
Active bool `json:"active" yaml:"active"`
}
// ProfileListOutput is the structured result of `netbird profile list`.
type ProfileListOutput struct {
Profiles []ProfileOutput `json:"profiles" yaml:"profiles"`
}
// JSON returns the ProfileListOutput as a JSON string.
func (o *ProfileListOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the ProfileListOutput as a YAML string.
func (o *ProfileListOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// StateOutput is one row of a state-list response.
type StateOutput struct {
Name string `json:"name" yaml:"name"`
}
// StateListOutput is the structured result of `netbird state list`.
type StateListOutput struct {
States []StateOutput `json:"states" yaml:"states"`
}
// JSON returns the StateListOutput as a JSON string.
func (o *StateListOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the StateListOutput as a YAML string.
func (o *StateListOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// DeregisterOutput is the structured result of `netbird deregister`.
type DeregisterOutput struct {
Status string `json:"status" yaml:"status"` // "deregistered"
ProfileName string `json:"profileName,omitempty" yaml:"profileName,omitempty"`
}
// JSON returns the DeregisterOutput as a JSON string.
func (o *DeregisterOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the DeregisterOutput as a YAML string.
func (o *DeregisterOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// LoginOutput is the structured result of `netbird login` after successful auth.
type LoginOutput struct {
Status string `json:"status" yaml:"status"` // "logged_in"
ProfileName string `json:"profileName,omitempty" yaml:"profileName,omitempty"`
}
// JSON returns the LoginOutput as a JSON string.
func (o *LoginOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the LoginOutput as a YAML string.
func (o *LoginOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// ProfileMutationOutput is the structured result of `netbird profile add`,
// `remove`, or `select`.
type ProfileMutationOutput struct {
Status string `json:"status" yaml:"status"` // "added" | "removed" | "selected"
ProfileName string `json:"profileName" yaml:"profileName"`
}
// JSON returns the ProfileMutationOutput as a JSON string.
func (o *ProfileMutationOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the ProfileMutationOutput as a YAML string.
func (o *ProfileMutationOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// NetworksMutationOutput is the structured result of `netbird networks select`
// or `netbird networks deselect`.
type NetworksMutationOutput struct {
Status string `json:"status" yaml:"status"` // "selected" | "deselected"
Networks []string `json:"networks,omitempty" yaml:"networks,omitempty"`
All bool `json:"all,omitempty" yaml:"all,omitempty"`
}
// JSON returns the NetworksMutationOutput as a JSON string.
func (o *NetworksMutationOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the NetworksMutationOutput as a YAML string.
func (o *NetworksMutationOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// DebugBundleOutput is the structured result of `netbird debug bundle`.
type DebugBundleOutput struct {
Path string `json:"path" yaml:"path"`
UploadedKey string `json:"uploadedKey,omitempty" yaml:"uploadedKey,omitempty"`
}
// JSON returns the DebugBundleOutput as a JSON string.
func (o *DebugBundleOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the DebugBundleOutput as a YAML string.
func (o *DebugBundleOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// VersionOutput is the structured result of `netbird version`.
type VersionOutput struct {
Version string `json:"version" yaml:"version"`
}
// JSON returns the VersionOutput as a JSON string.
func (o *VersionOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the VersionOutput as a YAML string.
func (o *VersionOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// JSON returns the DownOutput as a JSON string.
func (o *DownOutput) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the DownOutput as a YAML string.
func (o *DownOutput) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// GeneralSummary returns a general summary of the status overview.
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string
@@ -891,6 +547,16 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
daemonVersion := "N/A"
if o.DaemonVersion != "" {
daemonVersion = o.DaemonVersion
}
cliVersion := version.NetbirdVersion()
if o.CliVersion != "" {
cliVersion = o.CliVersion
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
@@ -911,8 +577,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"%s"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
o.DaemonVersion,
version.NetbirdVersion(),
daemonVersion,
cliVersion,
o.ProfileName,
managementConnString,
signalConnString,

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
uptypes "github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/version"
)
// Initial state for the debug collection
@@ -462,6 +463,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
request := &proto.DebugBundleRequest{
Anonymize: params.anonymize,
SystemInfo: params.systemInfo,
CliVersion: version.NetbirdVersion(),
}
if params.upload {
@@ -593,6 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
request := &proto.DebugBundleRequest{
Anonymize: anonymize,
SystemInfo: systemInfo,
CliVersion: version.NetbirdVersion(),
}
if uploadURL != "" {

2
go.mod
View File

@@ -24,13 +24,13 @@ require (
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.2.1
)
require (
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
github.com/DeRuina/timberjack v1.4.2
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.38.3
github.com/aws/aws-sdk-go-v2/config v1.31.6

4
go.sum
View File

@@ -29,6 +29,8 @@ github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/DeRuina/timberjack v1.4.2 h1:4bKlzhKdsR+2oNkgef9mqb4n11ICow8VK88RfzJPzN8=
github.com/DeRuina/timberjack v1.4.2/go.mod h1:RLoeQrwrCGIEF8gO5nV5b/gMD0QIy7bzQhBUgpp1EqE=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
@@ -938,8 +940,6 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=

View File

@@ -11,7 +11,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/store"
)
@@ -48,11 +47,6 @@ type EphemeralManager struct {
lifeTime time.Duration
cleanupWindow time.Duration
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
// no-op when the receiver is nil so deployments without an app
// metrics provider work unchanged.
metrics *telemetry.EphemeralPeersMetrics
}
// NewEphemeralManager instantiate new EphemeralManager
@@ -66,15 +60,6 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer
}
}
// SetMetrics attaches a metrics collector. Safe to call once before
// LoadInitialPeers; later attachment is fine but earlier loads won't be
// reflected in the gauge. Pass nil to detach.
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
e.peersLock.Lock()
e.metrics = m
e.peersLock.Unlock()
}
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
@@ -112,9 +97,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.removePeer(peer.ID) {
e.metrics.DecPending(1)
}
e.removePeer(peer.ID)
// stop the unnecessary timer
if e.headPeer == nil && e.timer != nil {
@@ -140,7 +123,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
}
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
e.metrics.IncPending()
if e.timer == nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
@@ -163,7 +145,6 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
for _, p := range peers {
e.addPeer(p.AccountID, p.ID, t)
}
e.metrics.AddPending(int64(len(peers)))
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
}
@@ -200,15 +181,6 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
e.peersLock.Unlock()
// Drop the gauge by the number of entries we just took off the list,
// regardless of whether the subsequent DeletePeers call succeeds. The
// list invariant is what the gauge tracks; failed delete batches are
// counted separately via CountCleanupError so we can still see them.
if len(deletePeers) > 0 {
e.metrics.CountCleanupRun()
e.metrics.DecPending(int64(len(deletePeers)))
}
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
@@ -219,10 +191,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
e.metrics.CountCleanupError()
continue
}
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
}
}
@@ -242,12 +211,9 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim
e.tailPeer = ep
}
// removePeer drops the entry from the linked list. Returns true if a
// matching entry was found and removed so callers can keep the pending
// metric gauge in sync.
func (e *EphemeralManager) removePeer(id string) bool {
func (e *EphemeralManager) removePeer(id string) {
if e.headPeer == nil {
return false
return
}
if e.headPeer.id == id {
@@ -255,7 +221,7 @@ func (e *EphemeralManager) removePeer(id string) bool {
if e.tailPeer.id == id {
e.tailPeer = nil
}
return true
return
}
for p := e.headPeer; p.next != nil; p = p.next {
@@ -265,10 +231,9 @@ func (e *EphemeralManager) removePeer(id string) bool {
e.tailPeer = p
}
p.next = p.next.next
return true
return
}
}
return false
}
func (e *EphemeralManager) isPeerOnList(id string) bool {

View File

@@ -112,11 +112,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
if metrics := s.Metrics(); metrics != nil {
em.SetMetrics(metrics.EphemeralPeersMetrics())
}
return em
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
})
}

View File

@@ -1868,32 +1868,35 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's
// network map and then marks the peer connected with a session token
// derived from syncTime (the moment the gRPC stream opened). Any
// concurrent stream that started earlier loses the optimistic-lock race
// in MarkPeerConnected and bails without writing.
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return peer, netMap, postureChecks, dnsfwdPort, nil
}
// OnPeerDisconnected is invoked when a sync stream ends. It marks the
// peer disconnected only when the stored SessionStartedAt matches the
// nanosecond token derived from streamStartTime — i.e. only when this
// is the stream that currently owns the peer's session. A mismatch
// means a newer stream has already replaced us, so the disconnect is
// dropped.
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}
if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
return nil

View File

@@ -61,8 +61,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error

View File

@@ -1305,31 +1305,17 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
}
// MarkPeerDisconnected mocks base method.
func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected.
func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime)
}
// OnPeerDisconnected mocks base method.

View File

@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1910,16 +1910,15 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
}, false)
require.NoError(t, err, "unable to add peer")
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should equal the token we passed in")
streamStartTime := time.Now().UTC()
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
@@ -1927,127 +1926,49 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0")
})
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
// Older stream tries to mark disconnect with its own (older) session token —
// fencing kicks in and the write is dropped.
staleStreamStartTime := streamStartTime.Add(-1 * time.Hour)
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because the stored session is newer than the disconnect token")
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should still hold the winning stream's token")
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
})
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should equal node2SyncTime token")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should still be connected")
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should NOT be overwritten by stale token from blocked goroutine")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
})
}
// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the
// fencing protocol under contention: many goroutines race to mark the
// same peer connected with distinct session tokens at the same time.
// The contract is that the highest token always wins and is what remains
// in the store, regardless of execution order.
func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
}, false)
require.NoError(t, err, "unable to add peer")
const workers = 16
base := time.Now().UTC().UnixNano()
tokens := make([]int64, workers)
for i := range tokens {
// Spread tokens by 1ms so the comparison is unambiguous; the
// largest is index workers-1.
tokens[i] = base + int64(i)*int64(time.Millisecond)
}
expected := tokens[workers-1]
var ready sync.WaitGroup
ready.Add(workers)
var start sync.WaitGroup
start.Add(1)
var done sync.WaitGroup
done.Add(workers)
// require.* calls t.FailNow which is documented as unsafe from
// non-test goroutines (it calls runtime.Goexit on the wrong stack and
// races with the WaitGroup). Collect errors here and assert from the
// main goroutine after done.Wait().
errs := make(chan error, workers)
for i := 0; i < workers; i++ {
token := tokens[i]
go func() {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
}()
}
ready.Wait()
start.Done()
done.Wait()
close(errs)
for err := range errs {
require.NoError(t, err, "MarkPeerConnected must not error under contention")
}
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected after the race")
require.Equal(t, expected, peer.Status.SessionStartedAt,
"the largest token must win regardless of execution order")
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -2070,7 +1991,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path"
"strings"
"github.com/dexidp/dex/storage"
@@ -140,13 +138,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
}
// Build CLI redirect URIs including the device callback. Dex uses the issuer-relative
// path (for example, /oauth2/device/callback) when completing the device flow, so
// include it explicitly in addition to the legacy bare path and absolute URL.
// Build CLI redirect URIs including the device callback (both relative and absolute)
cliRedirectURIs := c.CLIRedirectURIs
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer))
cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback")
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
dashboardRedirectURIs := c.DashboardRedirectURIs
@@ -159,10 +154,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
// MGMT api and the dashboard, adding baseURL means less configuration for the instance admin
dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL)
redirectURIs := make([]string, 0)
redirectURIs = append(redirectURIs, cliRedirectURIs...)
redirectURIs = append(redirectURIs, dashboardRedirectURIs...)
cfg := &dex.YAMLConfig{
Issuer: c.Issuer,
Storage: dex.Storage{
@@ -188,14 +179,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
ID: staticClientDashboard,
Name: "NetBird Dashboard",
Public: true,
RedirectURIs: redirectURIs,
RedirectURIs: dashboardRedirectURIs,
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
},
{
ID: staticClientCLI,
Name: "NetBird CLI",
Public: true,
RedirectURIs: redirectURIs,
RedirectURIs: cliRedirectURIs,
},
},
StaticConnectors: c.StaticConnectors,
@@ -226,14 +217,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
return cfg, nil
}
func issuerRelativeDeviceCallback(issuer string) string {
u, err := url.Parse(issuer)
if err != nil || u.Path == "" {
return "/device/callback"
}
return path.Join(u.Path, "/device/callback")
}
// Due to how the frontend generates the logout, sometimes it appends a trailing slash
// and because Dex only allows exact matches, we need to make sure we always have both
// versions of each provided uri
@@ -316,7 +299,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) {
}
}
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
}
func validSessionCookieEncryptionKeyLength(length int) bool {

View File

@@ -314,34 +314,6 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
})
}
func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) {
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "https://example.com/oauth2",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(t.TempDir(), "dex.db"),
},
},
}
yamlConfig, err := config.ToYAMLConfig()
require.NoError(t, err)
var cliRedirectURIs []string
for _, client := range yamlConfig.StaticClients {
if client.ID == staticClientCLI {
cliRedirectURIs = client.RedirectURIs
break
}
}
require.NotEmpty(t, cliRedirectURIs)
assert.Contains(t, cliRedirectURIs, "/device/callback")
assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback")
assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback")
}
func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) {
t.Setenv(sessionCookieEncryptionKeyEnv, "")

View File

@@ -38,8 +38,7 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
@@ -228,14 +227,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
// Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing
// hook so tests that inject MarkPeerDisconnectedFunc actually observe
// disconnect events. Falls through to nil when no hook is set, which
// is the original behaviour.
if am.MarkPeerDisconnectedFunc != nil {
return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano())
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
return nil
}
@@ -336,21 +328,13 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
if am.MarkPeerDisconnectedFunc != nil {
return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
if am.DeleteAccountFunc != nil {

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -28,7 +29,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -63,64 +63,56 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
// MarkPeerConnected marks a peer as connected with optimistic-locked
// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument
// is the start time of the gRPC sync stream that owns this update,
// expressed as Unix nanoseconds — only the call whose token is greater
// than what's stored wins. LastSeen is written by the database itself;
// we never pass it down.
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
}()
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
// syncTime is used as the LastSeen timestamp and for stale request detection
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
var err error
var skipped bool
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
outcome := telemetry.PeerStatusError
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
outcome = telemetry.PeerStatusPeerNotFound
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome)
return err
}
updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt)
if err != nil {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError)
return err
}
if !updated {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale)
log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID)
return nil
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
if am.geo != nil && realIP != nil {
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
expired := peer.Status != nil && peer.Status.LoginExpired
if peer.AddedWithSSOLogin() {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
if err != nil {
return err
}
if connected && !syncTime.After(peer.Status.LastSeen) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
skipped = true
return nil
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
return err
})
if skipped {
return nil
}
if err != nil {
return err
}
if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.schedulePeerLoginExpiration(ctx, accountID)
}
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
if expired {
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
@@ -128,60 +120,41 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
// MarkPeerDisconnected marks a peer as disconnected, but only when the
// stored session token matches the one passed in. A mismatch means a
// newer stream has already taken ownership of the peer — disconnects from
// the older stream are ignored. LastSeen is written by the database.
func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start))
}()
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = syncTime
newStatus.Connected = connected
// whenever peer got connected that means that it logged in successfully
if newStatus.Connected {
newStatus.LoginExpired = false
}
peer.Status = newStatus
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
outcome := telemetry.PeerStatusError
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
outcome = telemetry.PeerStatusPeerNotFound
if geo != nil && realIP != nil {
location, err := geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
} else {
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
err = transaction.SavePeerLocation(ctx, accountID, peer)
if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome)
return err
}
updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt)
if err != nil {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError)
return err
}
if !updated {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale)
log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping",
peer.ID, sessionStartedAt)
return nil
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
return nil
}
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
// updatePeerLocationIfChanged refreshes the geolocation on a separate
// row update, only when the connection IP actually changed. Geo lookups
// are expensive so we skip same-IP reconnects.
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return
}
location, err := am.geo.Lookup(realIP)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return
}
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
return false, err
}
return oldStatus.LoginExpired, nil
}
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.

View File

@@ -74,19 +74,8 @@ type ProxyMeta struct {
}
type PeerStatus struct { //nolint:revive
// LastSeen is the last time the peer status was updated (i.e. the last
// time we observed the peer being alive on a sync stream). Written by
// the database (CURRENT_TIMESTAMP) — callers do not supply it.
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
// SessionStartedAt records when the currently-active sync stream began,
// stored as Unix nanoseconds. It acts as the optimistic-locking token
// for status updates: a stream is only allowed to mutate the peer's
// status when its own token strictly exceeds the stored token (when connecting)
// or matches it exactly (for disconnects). Zero means "no
// active session". Integer nanoseconds are used so equality is
// precision-safe across drivers, and so the predicates compose to a
// single bigint comparison.
SessionStartedAt int64
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired
@@ -386,14 +375,10 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any {
return meta
}
// Copy PeerStatus. SessionStartedAt must be propagated so clone-based
// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently
// reset the fencing token to zero — that would let any subsequent
// SavePeerStatus write reopen the optimistic-lock window.
// Copy PeerStatus
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
SessionStartedAt: p.SessionStartedAt,
Connected: p.Connected,
LoginExpired: p.LoginExpired,
RequiresApproval: p.RequiresApproval,

View File

@@ -498,9 +498,8 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
peerCopy.Status = &peerStatus
fieldsToUpdate := []string{
"peer_status_last_seen", "peer_status_session_started_at",
"peer_status_connected", "peer_status_login_expired",
"peer_status_requires_approval",
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
}
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
@@ -517,69 +516,6 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
return nil
}
// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update.
// The peer is marked connected with the given session token only when
// the stored SessionStartedAt is strictly smaller than the incoming
// one — equivalently, when no newer stream has already taken ownership.
// The sentinel zero (set on peer creation or after a disconnect) counts
// as the smallest possible token. This is the write half of the
// fencing protocol described on PeerStatus.SessionStartedAt.
//
// The post-write side effects in the caller — geo lookup,
// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration,
// OnPeersUpdated — all run AFTER this method returns and are deliberately
// outside the database write so they cannot extend the row-lock window.
//
// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the
// moment the row is written. The caller never supplies LastSeen because
// the value would otherwise drift under lock contention — a Go-side
// time.Now() taken before the write can land minutes later than the
// actual UPDATE under load, which previously caused real ordering bugs.
func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
result := s.db.WithContext(ctx).
Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerID).
Where("peer_status_session_started_at < ?", newSessionStartedAt).
Updates(map[string]any{
"peer_status_connected": true,
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
"peer_status_session_started_at": newSessionStartedAt,
"peer_status_login_expired": false,
})
if result.Error != nil {
return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error)
}
return result.RowsAffected > 0, nil
}
// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update.
// The peer is marked disconnected only when the stored SessionStartedAt
// matches the incoming token — meaning the stream that owns the current
// session is the one ending. If a newer stream has already replaced the
// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at
// write time; see MarkPeerConnectedIfNewerSession for the rationale.
//
// A zero sessionStartedAt is rejected at the call site; the underlying
// WHERE on equality would otherwise match every never-connected peer.
func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
if sessionStartedAt == 0 {
return false, nil
}
result := s.db.WithContext(ctx).
Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerID).
Where("peer_status_session_started_at = ?", sessionStartedAt).
Updates(map[string]any{
"peer_status_connected": false,
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
"peer_status_session_started_at": int64(0),
})
if result.Error != nil {
return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error)
}
return result.RowsAffected > 0, nil
}
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
@@ -1787,10 +1723,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at,
peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip,
location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6
FROM peers WHERE account_id = $1`
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -1803,7 +1738,6 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
lastLogin, createdAt sql.NullTime
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
peerStatusLastSeen sql.NullTime
peerStatusSessionStartedAt sql.NullInt64
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
@@ -1818,9 +1752,8 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
&allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities,
&peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired,
&peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID,
&proxyEmbedded, &proxyCluster, &ipv6)
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6)
if err == nil {
if lastLogin.Valid {
@@ -1847,9 +1780,6 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
if peerStatusLastSeen.Valid {
p.Status.LastSeen = peerStatusLastSeen.Time
}
if peerStatusSessionStartedAt.Valid {
p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64
}
if peerStatusConnected.Valid {
p.Status.Connected = peerStatusConnected.Bool
}

View File

@@ -167,21 +167,6 @@ type Store interface {
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
// given session token, but only when the stored SessionStartedAt is
// strictly less than newSessionStartedAt (the sentinel zero counts as
// "older"). LastSeen is recorded by the database at the moment the
// row is updated — never by the caller — so it always reflects the
// real write time even under lock contention.
// Returns true when the update happened, false when this stream lost
// the race against a newer session.
MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error)
// MarkPeerDisconnectedIfSameSession sets the peer to disconnected and
// resets SessionStartedAt to zero, but only when the stored
// SessionStartedAt equals the given sessionStartedAt. LastSeen is
// recorded by the database. Returns true when the update happened,
// false when a newer session has taken over.
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error

View File

@@ -2878,36 +2878,6 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status)
}
// MarkPeerConnectedIfNewerSession mocks base method.
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
}
// MarkPeerDisconnectedIfSameSession mocks base method.
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
}
// SavePolicy mocks base method.
func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error {
m.ctrl.T.Helper()

View File

@@ -16,8 +16,6 @@ type AccountManagerMetrics struct {
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
peerMetaUpdateCount metric.Int64Counter
peerStatusUpdateCounter metric.Int64Counter
peerStatusUpdateDurationMs metric.Float64Histogram
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
@@ -66,24 +64,6 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err
}
// peerStatusUpdateCounter records every attempt to mark a peer as connected or disconnected
peerStatusUpdateCounter, err := meter.Int64Counter("management.account.peer.status.update.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of peer status update attempts, labeled by operation (connect|disconnect) and outcome (applied|stale|error|peer_not_found)"))
if err != nil {
return nil, err
}
peerStatusUpdateDurationMs, err := meter.Float64Histogram("management.account.peer.status.update.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
1, 5, 15, 25, 50, 100, 250, 500, 1000, 2000, 5000,
),
metric.WithDescription("Duration of a peer status update (fence UPDATE + post-write side effects), labeled by operation"))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
@@ -91,35 +71,10 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
updateAccountPeersCounter: updateAccountPeersCounter,
networkMapObjectCount: networkMapObjectCount,
peerMetaUpdateCount: peerMetaUpdateCount,
peerStatusUpdateCounter: peerStatusUpdateCounter,
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
}, nil
}
// PeerStatusOperation labels the kind of fence-locked peer status write.
type PeerStatusOperation string
// PeerStatusOutcome labels how a fence-locked peer status write resolved.
type PeerStatusOutcome string
const (
PeerStatusConnect PeerStatusOperation = "connect"
PeerStatusDisconnect PeerStatusOperation = "disconnect"
// PeerStatusApplied — the fence WHERE matched and the UPDATE landed.
PeerStatusApplied PeerStatusOutcome = "applied"
// PeerStatusStale — the fence WHERE rejected the write because a
// newer session has already taken ownership (connect: stored token
// >= incoming; disconnect: stored token != incoming).
PeerStatusStale PeerStatusOutcome = "stale"
// PeerStatusError — the store returned a non-NotFound error.
PeerStatusError PeerStatusOutcome = "error"
// PeerStatusPeerNotFound — the peer lookup failed (the peer was
// deleted between the gRPC sync handshake and the status write).
PeerStatusPeerNotFound PeerStatusOutcome = "peer_not_found"
)
// CountUpdateAccountPeersDuration counts the duration of updating account peers
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
@@ -149,23 +104,3 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
}
// CountPeerStatusUpdate increments the connect/disconnect counter,
// labeled by operation and outcome. Both labels are bounded enums.
func (metrics *AccountManagerMetrics) CountPeerStatusUpdate(op PeerStatusOperation, outcome PeerStatusOutcome) {
metrics.peerStatusUpdateCounter.Add(metrics.ctx, 1,
metric.WithAttributes(
attribute.String("operation", string(op)),
attribute.String("outcome", string(outcome)),
),
)
}
// RecordPeerStatusUpdateDuration records the wall-clock time spent
// running a peer status update (including post-write side effects),
// labeled by operation.
func (metrics *AccountManagerMetrics) RecordPeerStatusUpdateDuration(op PeerStatusOperation, d time.Duration) {
metrics.peerStatusUpdateDurationMs.Record(metrics.ctx, float64(d.Nanoseconds())/1e6,
metric.WithAttributes(attribute.String("operation", string(op))),
)
}

View File

@@ -29,7 +29,6 @@ type MockAppMetrics struct {
StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
EphemeralPeersMetricsFunc func() *EphemeralPeersMetrics
}
// GetMeter mocks the GetMeter function of the AppMetrics interface
@@ -104,14 +103,6 @@ func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
return nil
}
// EphemeralPeersMetrics mocks the MockAppMetrics function of the EphemeralPeersMetrics interface
func (mock *MockAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
if mock.EphemeralPeersMetricsFunc != nil {
return mock.EphemeralPeersMetricsFunc()
}
return nil
}
// AppMetrics is metrics interface
type AppMetrics interface {
GetMeter() metric2.Meter
@@ -123,7 +114,6 @@ type AppMetrics interface {
StoreMetrics() *StoreMetrics
UpdateChannelMetrics() *UpdateChannelMetrics
AccountManagerMetrics() *AccountManagerMetrics
EphemeralPeersMetrics() *EphemeralPeersMetrics
}
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
@@ -139,7 +129,6 @@ type defaultAppMetrics struct {
storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics
accountManagerMetrics *AccountManagerMetrics
ephemeralMetrics *EphemeralPeersMetrics
}
// IDPMetrics returns metrics for the idp package
@@ -172,11 +161,6 @@ func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetr
return appMetrics.accountManagerMetrics
}
// EphemeralPeersMetrics returns metrics for the ephemeral peer cleanup loop
func (appMetrics *defaultAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
return appMetrics.ephemeralMetrics
}
// Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil {
@@ -261,11 +245,6 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
}
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
@@ -275,7 +254,6 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
ephemeralMetrics: ephemeralMetrics,
}, nil
}
@@ -312,11 +290,6 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
}
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
@@ -327,6 +300,5 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
ephemeralMetrics: ephemeralMetrics,
}, nil
}

View File

@@ -1,115 +0,0 @@
package telemetry
import (
"context"
"go.opentelemetry.io/otel/metric"
)
// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how
// many peers are currently scheduled for deletion, how many tick runs
// the cleaner has performed, how many peers it has removed, and how
// many delete batches failed.
type EphemeralPeersMetrics struct {
ctx context.Context
pending metric.Int64UpDownCounter
cleanupRuns metric.Int64Counter
peersCleaned metric.Int64Counter
errors metric.Int64Counter
}
// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters.
func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) {
pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending",
metric.WithUnit("1"),
metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up"))
if err != nil {
return nil, err
}
cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer"))
if err != nil {
return nil, err
}
peersCleaned, err := meter.Int64Counter("management.ephemeral.peers.cleaned.counter",
metric.WithUnit("1"),
metric.WithDescription("Total number of ephemeral peers deleted by the cleanup loop"))
if err != nil {
return nil, err
}
errors, err := meter.Int64Counter("management.ephemeral.cleanup.errors.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of ephemeral cleanup batches (per account) that failed to delete"))
if err != nil {
return nil, err
}
return &EphemeralPeersMetrics{
ctx: ctx,
pending: pending,
cleanupRuns: cleanupRuns,
peersCleaned: peersCleaned,
errors: errors,
}, nil
}
// All methods are nil-receiver safe so callers that haven't wired metrics
// (tests, self-hosted with metrics off) can invoke them unconditionally.
// IncPending bumps the pending gauge when a peer is added to the cleanup list.
func (m *EphemeralPeersMetrics) IncPending() {
if m == nil {
return
}
m.pending.Add(m.ctx, 1)
}
// AddPending bumps the pending gauge by n — used at startup when the
// initial set of ephemeral peers is loaded from the store.
func (m *EphemeralPeersMetrics) AddPending(n int64) {
if m == nil || n <= 0 {
return
}
m.pending.Add(m.ctx, n)
}
// DecPending decreases the pending gauge — used both when a peer reconnects
// before its deadline (removed from the list) and when a cleanup tick
// actually deletes it.
func (m *EphemeralPeersMetrics) DecPending(n int64) {
if m == nil || n <= 0 {
return
}
m.pending.Add(m.ctx, -n)
}
// CountCleanupRun records one cleanup pass that processed >0 peers. Idle
// ticks (nothing to do) deliberately don't increment so the rate
// reflects useful work.
func (m *EphemeralPeersMetrics) CountCleanupRun() {
if m == nil {
return
}
m.cleanupRuns.Add(m.ctx, 1)
}
// CountPeersCleaned records the number of peers a single tick deleted.
func (m *EphemeralPeersMetrics) CountPeersCleaned(n int64) {
if m == nil || n <= 0 {
return
}
m.peersCleaned.Add(m.ctx, n)
}
// CountCleanupError records a failed delete batch.
func (m *EphemeralPeersMetrics) CountCleanupError() {
if m == nil {
return
}
m.errors.Add(m.ctx, 1)
}

View File

@@ -1,15 +1,16 @@
package util
import (
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strconv"
"github.com/DeRuina/timberjack"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/grpclog"
"gopkg.in/natefinch/lumberjack.v2"
"github.com/netbirdio/netbird/formatter"
)
@@ -59,7 +60,12 @@ func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
case "":
logger.Warnf("empty log path received: %#v", logPath)
default:
writers = append(writers, newRotatedOutput(logPath))
writer, err := setupLogFile(logPath, isRotationDisabled(logger))
if err != nil {
logger.Errorf("failed setting up log file: %s, %s", logPath, err)
return err
}
writers = append(writers, writer)
}
}
@@ -94,17 +100,43 @@ func FindFirstLogPath(logs []string) string {
return ""
}
func isRotationDisabled(logger *log.Logger) bool {
v, _ := os.LookupEnv("NB_LOG_DISABLE_ROTATION")
disabled, _ := strconv.ParseBool(v)
if disabled {
logger.Warnf("log rotation is disabled by env flag")
return true
}
conflict, configPath := FindFirstLogrotateConflict()
if conflict {
logger.Warnf("log rotation conflict detected in: %#v, rotation is disabled", configPath)
return true
}
return false
}
func setupLogFile(logPath string, disableRotation bool) (io.Writer, error) {
if disableRotation {
file, err := os.OpenFile(logPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
if err != nil {
return nil, fmt.Errorf("failed opening log file: %s", err)
}
return file, nil
}
return newRotatedOutput(logPath), nil
}
func newRotatedOutput(logPath string) io.Writer {
maxLogSize := getLogMaxSize()
lumberjackLogger := &lumberjack.Logger{
timberjackLogger := &timberjack.Logger{
// Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath),
MaxSize: maxLogSize, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compress: true,
Filename: filepath.ToSlash(logPath),
MaxSize: maxLogSize, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compression: "gzip",
}
return lumberjackLogger
return timberjackLogger
}
func setGRPCLibLogger(logger *log.Logger) {

96
util/log_test.go Normal file
View File

@@ -0,0 +1,96 @@
package util
import (
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
// TestSetupLogFile_RotatesOnSize drives >MaxSize bytes through the writer
// returned by setupLogFile and asserts a backup file appears.
func TestSetupLogFile_RotatesOnSize(t *testing.T) {
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
dir := t.TempDir()
logPath := filepath.Join(dir, "netbird.log")
w, err := setupLogFile(logPath, false)
require.NoError(t, err)
t.Cleanup(func() {
if c, ok := w.(io.Closer); ok {
_ = c.Close()
}
})
chunk := []byte(strings.Repeat("x", 64*1024) + "\n")
for range 20 {
_, err := w.Write(chunk)
require.NoError(t, err)
}
info, err := os.Stat(logPath)
require.NoError(t, err)
require.Less(t, info.Size(), int64(1<<20),
"active log should be < 1 MB after rotation, got %d", info.Size())
require.Eventually(t, func() bool {
entries, _ := os.ReadDir(dir)
for _, e := range entries {
name := e.Name()
if name == filepath.Base(logPath) {
continue
}
if strings.HasPrefix(name, "netbird-") && strings.HasSuffix(name, ".log.gz") {
return true
}
}
return false
}, 5*time.Second, 50*time.Millisecond, "expected a rotated backup file in %s", dir)
}
// TestSetupLogFile_RotationDisabled verifies that with rotation off, the file
// grows past MaxSize and no backups are created.
func TestSetupLogFile_RotationDisabled(t *testing.T) {
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
dir := t.TempDir()
logPath := filepath.Join(dir, "netbird.log")
w, err := setupLogFile(logPath, true)
require.NoError(t, err)
f, ok := w.(*os.File)
require.True(t, ok, "expected plain *os.File when rotation is disabled, got %T", w)
t.Cleanup(func() { _ = f.Close() })
chunk := []byte(strings.Repeat("y", 64*1024) + "\n")
for range 20 {
_, err := w.Write(chunk)
require.NoError(t, err)
}
info, err := os.Stat(logPath)
require.NoError(t, err)
require.GreaterOrEqual(t, info.Size(), int64(1<<20),
"file should exceed MaxSize when rotation is disabled, got %d", info.Size())
entries, err := os.ReadDir(dir)
require.NoError(t, err)
require.Len(t, entries, 1, "no backup files should exist when rotation is disabled, got %v", entries)
}
// TestIsRotationDisabled_EnvFlag covers the NB_LOG_DISABLE_ROTATION env path.
// The logrotate-conflict branch is exercised separately on linux.
func TestIsRotationDisabled_EnvFlag(t *testing.T) {
logger := log.New()
logger.SetOutput(io.Discard)
t.Setenv("NB_LOG_DISABLE_ROTATION", "true")
require.True(t, isRotationDisabled(logger))
}

93
util/logrotate_linux.go Normal file
View File

@@ -0,0 +1,93 @@
//go:build linux
package util
import (
"bufio"
"errors"
"io/fs"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
const (
defaultLogrotateConfPath = "/etc/logrotate.conf"
defaultLogrotateConfDir = "/etc/logrotate.d"
netbirdString = "netbird"
)
// FindLogrotateConflicts scans the standard logrotate locations for
// indications of conflict with netbird. It returns true and the config file
// path if a conflict was found.
func FindFirstLogrotateConflict() (bool, string) {
return findFirstLogrotateConflictIn(defaultLogrotateConfPath, defaultLogrotateConfDir)
}
func findFirstLogrotateConflictIn(confPath, confDir string) (bool, string) {
for _, f := range listLogrotateConfigs(confPath, confDir) {
present, err := scanLogrotateFile(f, netbirdString)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Debugf("scan %s: %v", f, err)
}
continue
}
if present {
return present, f
}
}
return false, ""
}
// listLogrotateConfigs returns all config files for logrotate.
func listLogrotateConfigs(confPath, confDir string) []string {
files := []string{confPath}
entries, err := os.ReadDir(confDir)
if err != nil {
return files
}
for _, e := range entries {
if e.IsDir() {
continue
}
files = append(files, filepath.Join(confDir, e.Name()))
}
return files
}
// scanLogrotateFile reads a config and reports if a non-comment line
// contains the given substring.
func scanLogrotateFile(path string, substring string) (bool, error) {
f, err := os.Open(path)
if err != nil {
return false, err
}
defer func() {
if err := f.Close(); err != nil {
log.Debugf("close %s: %v", path, err)
}
}()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(stripLogrotateComment(scanner.Text()))
if line == "" {
continue
}
if strings.Contains(line, substring) {
return true, nil
}
}
if err := scanner.Err(); err != nil {
return false, err
}
return false, nil
}
func stripLogrotateComment(line string) string {
before, _, _ := strings.Cut(line, "#")
return before
}

View File

@@ -0,0 +1,95 @@
//go:build linux
package util
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestFindFirstLogrotateConflict(t *testing.T) {
t.Run("conflict in confDir", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
conflictPath := filepath.Join(confDir, "netbird")
writeLogrotateConfig(t, conflictPath, `/var/log/netbird/*.log {
daily
rotate 7
}`)
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.True(t, got)
require.Equal(t, conflictPath, path)
})
t.Run("conflict in main conf file", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
writeLogrotateConfig(t, confPath, `weekly
rotate 4
include /etc/logrotate.d
/var/log/netbird/client.log { rotate 5 }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.True(t, got)
require.Equal(t, confPath, path)
})
t.Run("no conflict when netbird is absent", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
writeLogrotateConfig(t, filepath.Join(confDir, "syslog"), `/var/log/syslog { weekly }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.False(t, got)
require.Empty(t, path)
})
t.Run("commented-out netbird line is ignored", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
writeLogrotateConfig(t, filepath.Join(confDir, "misc"), `# /var/log/netbird/*.log { daily }
/var/log/other.log { weekly }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.False(t, got)
require.Empty(t, path)
})
t.Run("subdirectories in confDir are ignored", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
sub := filepath.Join(confDir, "nested")
require.NoError(t, os.MkdirAll(sub, 0o755))
writeLogrotateConfig(t, filepath.Join(sub, "netbird"), `/var/log/netbird/*.log { daily }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.False(t, got)
require.Empty(t, path)
})
t.Run("missing paths return no conflict", func(t *testing.T) {
dir := t.TempDir()
got, path := findFirstLogrotateConflictIn(
filepath.Join(dir, "does-not-exist.conf"),
filepath.Join(dir, "does-not-exist.d"),
)
require.False(t, got)
require.Empty(t, path)
})
}
// newLogrotateLayout creates a temp logrotate.conf path and logrotate.d dir,
// returning their paths. The conf file itself is not created.
func newLogrotateLayout(t *testing.T) (confPath, confDir string) {
t.Helper()
root := t.TempDir()
confDir = filepath.Join(root, "logrotate.d")
require.NoError(t, os.MkdirAll(confDir, 0o755))
return filepath.Join(root, "logrotate.conf"), confDir
}
func writeLogrotateConfig(t *testing.T, path, body string) {
t.Helper()
require.NoError(t, os.WriteFile(path, []byte(body), 0o644))
}

View File

@@ -0,0 +1,10 @@
//go:build !linux
package util
// FindLogrotateConflicts scans the standard logrotate locations for
// indications of conflict with netbird. It will always return false for
// non-linux devices.
func FindFirstLogrotateConflict() (bool, string) {
return false, ""
}