mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-17 14:19:54 +00:00
Merge remote-tracking branch 'origin/main' into fix/login-cmd-root-flags
This commit is contained in:
196
client/cmd/capture.go
Normal file
196
client/cmd/capture.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
var captureCmd = &cobra.Command{
|
||||
Use: "capture",
|
||||
Short: "Capture packets on the WireGuard interface",
|
||||
Long: `Captures decrypted packets flowing through the WireGuard interface.
|
||||
|
||||
Default output is human-readable text. Use --pcap or --output for pcap binary.
|
||||
Requires --enable-capture to be set at service install or reconfigure time.
|
||||
|
||||
Examples:
|
||||
netbird debug capture
|
||||
netbird debug capture host 100.64.0.1 and port 443
|
||||
netbird debug capture tcp
|
||||
netbird debug capture icmp
|
||||
netbird debug capture src host 10.0.0.1 and dst port 80
|
||||
netbird debug capture -o capture.pcap
|
||||
netbird debug capture --pcap | tshark -r -
|
||||
netbird debug capture --pcap | tcpdump -r - -n`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
RunE: runCapture,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.AddCommand(captureCmd)
|
||||
|
||||
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
|
||||
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
|
||||
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
|
||||
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
|
||||
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
|
||||
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
|
||||
}
|
||||
|
||||
func runCapture(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
cmd.PrintErrf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
req, err := buildCaptureRequest(cmd, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.StartCapture(ctx, req)
|
||||
if err != nil {
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
|
||||
// First Recv is the empty acceptance message from the server. If the
|
||||
// device is unavailable (kernel WG, not connected, capture disabled),
|
||||
// the server returns an error instead.
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
|
||||
out, cleanup, err := captureOutput(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if req.TextOutput {
|
||||
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
|
||||
} else {
|
||||
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
|
||||
}
|
||||
|
||||
streamErr := streamCapture(ctx, cmd, stream, out)
|
||||
cleanupErr := cleanup()
|
||||
if streamErr != nil {
|
||||
return streamErr
|
||||
}
|
||||
return cleanupErr
|
||||
}
|
||||
|
||||
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
|
||||
req := &proto.StartCaptureRequest{}
|
||||
|
||||
if len(args) > 0 {
|
||||
expr := strings.Join(args, " ")
|
||||
if _, err := capture.ParseFilter(expr); err != nil {
|
||||
return nil, fmt.Errorf("invalid filter: %w", err)
|
||||
}
|
||||
req.FilterExpr = expr
|
||||
}
|
||||
|
||||
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
|
||||
req.SnapLen = snap
|
||||
}
|
||||
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
|
||||
if d < 0 {
|
||||
return nil, fmt.Errorf("duration must not be negative")
|
||||
}
|
||||
req.Duration = durationpb.New(d)
|
||||
}
|
||||
req.Verbose, _ = cmd.Flags().GetBool("verbose")
|
||||
req.Ascii, _ = cmd.Flags().GetBool("ascii")
|
||||
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
req.TextOutput = !forcePcap && outPath == ""
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
|
||||
for {
|
||||
pkt, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.PrintErrf("\nCapture stopped.\n")
|
||||
return nil //nolint:nilerr // user interrupted
|
||||
}
|
||||
if err == io.EOF {
|
||||
cmd.PrintErrf("\nCapture finished.\n")
|
||||
return nil
|
||||
}
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
if _, err := out.Write(pkt.GetData()); err != nil {
|
||||
return fmt.Errorf("write output: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// captureOutput returns the writer for capture data and a cleanup function
|
||||
// that finalizes the file. Errors from the cleanup must be propagated.
|
||||
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
if outPath == "" {
|
||||
return os.Stdout, func() error { return nil }, nil
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create output file: %w", err)
|
||||
}
|
||||
tmpPath := f.Name()
|
||||
return f, func() error {
|
||||
var merr *multierror.Error
|
||||
if err := f.Close(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
|
||||
}
|
||||
fi, statErr := os.Stat(tmpPath)
|
||||
if statErr != nil || fi.Size() == 0 {
|
||||
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
if err := os.Rename(tmpPath, outPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
cmd.PrintErrf("Wrote %s\n", outPath)
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func handleCaptureError(err error) error {
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return fmt.Errorf("%s", s.Message())
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
@@ -98,7 +98,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -136,6 +135,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
if level == proto.LogLevel_UNKNOWN {
|
||||
//nolint
|
||||
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
|
||||
}
|
||||
|
||||
@@ -182,10 +182,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
@@ -199,10 +200,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
needsRestoreUp := false
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = !stateWasDown
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@@ -210,31 +214,88 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = false
|
||||
cmd.Println("netbird up")
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
cpuProfilingStarted := false
|
||||
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
|
||||
} else {
|
||||
cpuProfilingStarted = true
|
||||
defer func() {
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
captureStarted := false
|
||||
if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
|
||||
captureTimeout := duration + 30*time.Second
|
||||
const maxBundleCapture = 10 * time.Minute
|
||||
if captureTimeout > maxBundleCapture {
|
||||
captureTimeout = maxBundleCapture
|
||||
}
|
||||
_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
|
||||
Timeout: durationpb.New(captureTimeout),
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
captureStarted = true
|
||||
cmd.Println("Packet capture started.")
|
||||
// Safety: always stop on exit, even if the normal stop below runs too.
|
||||
defer func() {
|
||||
if captureStarted {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
if captureStarted {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
||||
} else {
|
||||
captureStarted = false
|
||||
cmd.Println("Packet capture stopped.")
|
||||
}
|
||||
}
|
||||
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
} else {
|
||||
cpuProfilingStarted = false
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -246,18 +307,28 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if needsRestoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up (restored)")
|
||||
}
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
||||
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
@@ -301,25 +372,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context(), true)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -378,7 +430,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogFile: logFilePath,
|
||||
LogPath: logFilePath,
|
||||
CPUProfile: nil,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
@@ -403,4 +456,5 @@ func init() {
|
||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
|
||||
}
|
||||
|
||||
287
client/cmd/expose.go
Normal file
287
client/cmd/expose.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||
|
||||
var (
|
||||
exposePin string
|
||||
exposePassword string
|
||||
exposeUserGroups []string
|
||||
exposeDomain string
|
||||
exposeNamePrefix string
|
||||
exposeProtocol string
|
||||
exposeExternalPort uint16
|
||||
)
|
||||
|
||||
var exposeCmd = &cobra.Command{
|
||||
Use: "expose <port>",
|
||||
Short: "Expose a local port via the NetBird reverse proxy",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Example: ` netbird expose --with-password safe-pass 8080
|
||||
netbird expose --protocol tcp 5432
|
||||
netbird expose --protocol tcp --with-external-port 5433 5432
|
||||
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
|
||||
RunE: exposeFn,
|
||||
}
|
||||
|
||||
func init() {
|
||||
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
|
||||
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
|
||||
}
|
||||
|
||||
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
|
||||
func isClusterProtocol(protocol string) bool {
|
||||
switch strings.ToLower(protocol) {
|
||||
case "tcp", "udp", "tls":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
|
||||
// where domain display doesn't apply. TLS uses SNI so it has a domain.
|
||||
func isPortBasedProtocol(protocol string) bool {
|
||||
switch strings.ToLower(protocol) {
|
||||
case "tcp", "udp":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// extractPort returns the port portion of a URL like "tcp://host:12345", or
|
||||
// falls back to the given default formatted as a string.
|
||||
func extractPort(serviceURL string, fallback uint16) string {
|
||||
u := serviceURL
|
||||
if idx := strings.Index(u, "://"); idx != -1 {
|
||||
u = u[idx+3:]
|
||||
}
|
||||
if i := strings.LastIndex(u, ":"); i != -1 {
|
||||
if p := u[i+1:]; p != "" {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return strconv.FormatUint(uint64(fallback), 10)
|
||||
}
|
||||
|
||||
// resolveExternalPort returns the effective external port, defaulting to the target port.
|
||||
func resolveExternalPort(targetPort uint64) uint16 {
|
||||
if exposeExternalPort != 0 {
|
||||
return exposeExternalPort
|
||||
}
|
||||
return uint16(targetPort)
|
||||
}
|
||||
|
||||
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||
}
|
||||
if port == 0 || port > 65535 {
|
||||
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||
}
|
||||
|
||||
if !isProtocolValid(exposeProtocol) {
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||
}
|
||||
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
|
||||
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
|
||||
}
|
||||
} else if cmd.Flags().Changed("with-external-port") {
|
||||
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
|
||||
}
|
||||
|
||||
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||
return 0, fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||
return 0, fmt.Errorf("user groups cannot be empty")
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func isProtocolValid(exposeProtocol string) bool {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http", "https", "tcp", "udp", "tls":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
log.Errorf("failed initializing log %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = false
|
||||
|
||||
port, err := validateExposeFlags(cmd, args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = true
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
cancel()
|
||||
}()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close daemon connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
protocol, err := toExposeProtocol(exposeProtocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &proto.ExposeServiceRequest{
|
||||
Port: uint32(port),
|
||||
Protocol: protocol,
|
||||
Pin: exposePin,
|
||||
Password: exposePassword,
|
||||
UserGroups: exposeUserGroups,
|
||||
Domain: exposeDomain,
|
||||
NamePrefix: exposeNamePrefix,
|
||||
}
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
req.ListenPort = uint32(resolveExternalPort(port))
|
||||
}
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return waitForExposeEvents(cmd, ctx, stream)
|
||||
}
|
||||
|
||||
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
p, err := expose.ParseProtocolType(exposeProtocol)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
switch p {
|
||||
case expose.ProtocolHTTP:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||
case expose.ProtocolHTTPS:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||
case expose.ProtocolTCP:
|
||||
return proto.ExposeProtocol_EXPOSE_TCP, nil
|
||||
case expose.ProtocolUDP:
|
||||
return proto.ExposeProtocol_EXPOSE_UDP, nil
|
||||
case expose.ProtocolTLS:
|
||||
return proto.ExposeProtocol_EXPOSE_TLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unhandled protocol type: %d", p)
|
||||
}
|
||||
}
|
||||
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||
}
|
||||
printExposeReady(cmd, ready.Ready, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
|
||||
cmd.Println("Service exposed successfully!")
|
||||
cmd.Printf(" Name: %s\n", r.ServiceName)
|
||||
if r.ServiceUrl != "" {
|
||||
cmd.Printf(" URL: %s\n", r.ServiceUrl)
|
||||
}
|
||||
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
|
||||
cmd.Printf(" Domain: %s\n", r.Domain)
|
||||
}
|
||||
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||
cmd.Printf(" Internal: %d\n", port)
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
|
||||
}
|
||||
if r.PortAutoAssigned && exposeExternalPort != 0 {
|
||||
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
|
||||
}
|
||||
cmd.Println()
|
||||
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||
}
|
||||
|
||||
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||
for {
|
||||
_, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.Println("\nService stopped.")
|
||||
//nolint:nilerr
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||
}
|
||||
return fmt.Errorf("stream error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
}
|
||||
@@ -81,6 +82,7 @@ var loginCmd = &cobra.Command{
|
||||
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -224,6 +226,7 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -273,7 +276,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
}
|
||||
|
||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
|
||||
|
||||
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -293,18 +296,15 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, "", "")
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check login required: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -316,23 +316,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -358,13 +344,9 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
@@ -372,7 +354,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
@@ -386,6 +368,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
if showQR {
|
||||
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
|
||||
printQRCode(f, verificationURIComplete)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
|
||||
25
client/cmd/qr.go
Normal file
25
client/cmd/qr.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
)
|
||||
|
||||
// printQRCode prints a QR code for the given URL to the writer.
|
||||
// Called only when the user explicitly requests QR output via --qr.
|
||||
func printQRCode(w io.Writer, url string) {
|
||||
if url == "" {
|
||||
return
|
||||
}
|
||||
qrterminal.GenerateWithConfig(url, qrterminal.Config{
|
||||
Level: qrterminal.M,
|
||||
Writer: w,
|
||||
HalfBlocks: true,
|
||||
BlackChar: qrterminal.BLACK_BLACK,
|
||||
WhiteChar: qrterminal.WHITE_WHITE,
|
||||
BlackWhiteChar: qrterminal.BLACK_WHITE,
|
||||
WhiteBlackChar: qrterminal.WHITE_BLACK,
|
||||
QuietZone: qrterminal.QUIET_ZONE,
|
||||
})
|
||||
}
|
||||
26
client/cmd/qr_test.go
Normal file
26
client/cmd/qr_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrintQRCode_EmptyURL(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "")
|
||||
|
||||
if buf.Len() != 0 {
|
||||
t.Error("expected no output for empty URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintQRCode_WritesOutput(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "https://example.com/auth")
|
||||
|
||||
if buf.Len() == 0 {
|
||||
t.Error("expected QR code output for non-empty URL")
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -74,12 +75,23 @@ var (
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
Long: "",
|
||||
SilenceUsage: true,
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(cmd.Root())
|
||||
|
||||
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||
if !isServiceCmd(cmd) {
|
||||
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -144,6 +156,7 @@ func init() {
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(profileCmd)
|
||||
rootCmd.AddCommand(exposeCmd)
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
@@ -385,11 +398,11 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
}
|
||||
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -397,3 +410,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||
func isServiceCmd(cmd *cobra.Command) bool {
|
||||
for c := cmd; c != nil; c = c.Parent() {
|
||||
if c.Name() == "service" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -41,13 +41,17 @@ func init() {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
|
||||
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
|
||||
`Use --service-env "" to clear all saved env vars. ` +
|
||||
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
}
|
||||
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
|
||||
if err := serverInstance.Start(); err != nil {
|
||||
log.Fatalf("failed to start daemon: %v", err)
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
|
||||
|
||||
// Common setup for service control commands
|
||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -59,6 +59,14 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-update-settings")
|
||||
}
|
||||
|
||||
if captureEnabled {
|
||||
args = append(args, "--enable-capture")
|
||||
}
|
||||
|
||||
if networksDisabled {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
@@ -119,6 +127,10 @@ var installCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
svcConfig, err := createServiceConfigForInstall()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -136,6 +148,10 @@ var installCmd = &cobra.Command{
|
||||
return fmt.Errorf("install service: %w", err)
|
||||
}
|
||||
|
||||
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||
}
|
||||
|
||||
cmd.Println("NetBird service has been installed")
|
||||
return nil
|
||||
},
|
||||
@@ -187,6 +203,10 @@ This command will temporarily stop the service, update its configuration, and re
|
||||
return err
|
||||
}
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
wasRunning, err := isServiceRunning()
|
||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||
return fmt.Errorf("check service status: %w", err)
|
||||
@@ -222,6 +242,10 @@ This command will temporarily stop the service, update its configuration, and re
|
||||
return fmt.Errorf("install service with new config: %w", err)
|
||||
}
|
||||
|
||||
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||
}
|
||||
|
||||
if wasRunning {
|
||||
cmd.Println("Starting NetBird service...")
|
||||
if err := s.Start(); err != nil {
|
||||
|
||||
224
client/cmd/service_params.go
Normal file
224
client/cmd/service_params.go
Normal file
@@ -0,0 +1,224 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const serviceParamsFile = "service.json"
|
||||
|
||||
// serviceParams holds install-time service parameters that persist across
|
||||
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
|
||||
type serviceParams struct {
|
||||
LogLevel string `json:"log_level"`
|
||||
DaemonAddr string `json:"daemon_addr"`
|
||||
ManagementURL string `json:"management_url,omitempty"`
|
||||
ConfigPath string `json:"config_path,omitempty"`
|
||||
LogFiles []string `json:"log_files,omitempty"`
|
||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
// serviceParamsPath returns the path to the service params file.
|
||||
func serviceParamsPath() string {
|
||||
return filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
}
|
||||
|
||||
// loadServiceParams reads saved service parameters from disk.
|
||||
// Returns nil with no error if the file does not exist.
|
||||
func loadServiceParams() (*serviceParams, error) {
|
||||
path := serviceParamsPath()
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return nil, fmt.Errorf("read service params %s: %w", path, err)
|
||||
}
|
||||
|
||||
var params serviceParams
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return nil, fmt.Errorf("parse service params %s: %w", path, err)
|
||||
}
|
||||
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
// saveServiceParams writes current service parameters to disk atomically
|
||||
// with restricted permissions.
|
||||
func saveServiceParams(params *serviceParams) error {
|
||||
path := serviceParamsPath()
|
||||
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
|
||||
return fmt.Errorf("save service params: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// currentServiceParams captures the current state of all package-level
|
||||
// variables into a serviceParams struct.
|
||||
func currentServiceParams() *serviceParams {
|
||||
params := &serviceParams{
|
||||
LogLevel: logLevel,
|
||||
DaemonAddr: daemonAddr,
|
||||
ManagementURL: managementURL,
|
||||
ConfigPath: configPath,
|
||||
LogFiles: logFiles,
|
||||
DisableProfiles: profilesDisabled,
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
EnableCapture: captureEnabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err == nil {
|
||||
params.ServiceEnvVars = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// loadAndApplyServiceParams loads saved params from disk and applies them
|
||||
// to any flags that were not explicitly set.
|
||||
func loadAndApplyServiceParams(cmd *cobra.Command) error {
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
applyServiceParams(cmd, params)
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyServiceParams merges saved parameters into package-level variables
|
||||
// for any flag that was not explicitly set by the user (via CLI or env var).
|
||||
// Flags that were Changed() are left untouched.
|
||||
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For fields with non-empty defaults (log-level, daemon-addr), keep the
|
||||
// != "" guard so that an older service.json missing the field doesn't
|
||||
// clobber the default with an empty string.
|
||||
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
||||
logLevel = params.LogLevel
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
|
||||
daemonAddr = params.DaemonAddr
|
||||
}
|
||||
|
||||
// For optional fields where empty means "use default", always apply so
|
||||
// that an explicit clear (--management-url "") persists across reinstalls.
|
||||
if !rootCmd.PersistentFlags().Changed("management-url") {
|
||||
managementURL = params.ManagementURL
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("config") {
|
||||
configPath = params.ConfigPath
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("log-file") {
|
||||
logFiles = params.LogFiles
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
|
||||
profilesDisabled = params.DisableProfiles
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
|
||||
updateSettingsDisabled = params.DisableUpdateSettings
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("enable-capture") {
|
||||
captureEnabled = params.EnableCapture
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
// applyServiceEnvParams merges saved service environment variables.
|
||||
// If --service-env was explicitly set with values, explicit values win on key
|
||||
// conflict but saved keys not in the explicit set are carried over.
|
||||
// If --service-env was explicitly set to empty, all saved env vars are cleared.
|
||||
// If --service-env was not set, saved env vars are used entirely.
|
||||
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||
if !cmd.Flags().Changed("service-env") {
|
||||
if len(params.ServiceEnvVars) > 0 {
|
||||
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Flag was explicitly set: parse what the user provided.
|
||||
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the user passed an empty value (e.g. --service-env ""), clear all
|
||||
// saved env vars rather than merging.
|
||||
if len(explicit) == 0 {
|
||||
serviceEnvVars = nil
|
||||
return
|
||||
}
|
||||
|
||||
if len(params.ServiceEnvVars) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Merge saved values underneath explicit ones.
|
||||
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||
maps.Copy(merged, params.ServiceEnvVars)
|
||||
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||
serviceEnvVars = envMapToSlice(merged)
|
||||
}
|
||||
|
||||
var resetParamsCmd = &cobra.Command{
|
||||
Use: "reset-params",
|
||||
Short: "Remove saved service install parameters",
|
||||
Long: "Removes the saved service.json file so the next install uses default parameters.",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
path := serviceParamsPath()
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
cmd.Println("No saved service parameters found")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("remove service params: %w", err)
|
||||
}
|
||||
cmd.Printf("Removed saved service parameters (%s)\n", path)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
|
||||
func envMapToSlice(m map[string]string) []string {
|
||||
s := make([]string, 0, len(m))
|
||||
for k, v := range m {
|
||||
s = append(s, k+"="+v)
|
||||
}
|
||||
return s
|
||||
}
|
||||
560
client/cmd/service_params_test.go
Normal file
560
client/cmd/service_params_test.go
Normal file
@@ -0,0 +1,560 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
)
|
||||
|
||||
func TestServiceParamsPath(t *testing.T) {
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
|
||||
configs.StateDir = "/var/lib/netbird"
|
||||
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
||||
|
||||
configs.StateDir = "/custom/state"
|
||||
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
||||
}
|
||||
|
||||
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
params := &serviceParams{
|
||||
LogLevel: "debug",
|
||||
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||
ManagementURL: "https://my.server.com",
|
||||
ConfigPath: "/etc/netbird/config.json",
|
||||
LogFiles: []string{"/var/log/netbird/client.log", "console"},
|
||||
DisableProfiles: true,
|
||||
DisableUpdateSettings: false,
|
||||
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
|
||||
}
|
||||
|
||||
err := saveServiceParams(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and is valid JSON.
|
||||
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, json.Valid(data))
|
||||
|
||||
loaded, err := loadServiceParams()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loaded)
|
||||
|
||||
assert.Equal(t, params.LogLevel, loaded.LogLevel)
|
||||
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
|
||||
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
|
||||
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
|
||||
assert.Equal(t, params.LogFiles, loaded.LogFiles)
|
||||
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
|
||||
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
|
||||
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
|
||||
}
|
||||
|
||||
func TestLoadServiceParams_FileNotExists(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
params, err := loadServiceParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, params)
|
||||
}
|
||||
|
||||
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
params, err := loadServiceParams()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, params)
|
||||
}
|
||||
|
||||
func TestCurrentServiceParams(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
origDaemonAddr := daemonAddr
|
||||
origManagementURL := managementURL
|
||||
origConfigPath := configPath
|
||||
origLogFiles := logFiles
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() {
|
||||
logLevel = origLogLevel
|
||||
daemonAddr = origDaemonAddr
|
||||
managementURL = origManagementURL
|
||||
configPath = origConfigPath
|
||||
logFiles = origLogFiles
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
serviceEnvVars = origServiceEnvVars
|
||||
})
|
||||
|
||||
logLevel = "trace"
|
||||
daemonAddr = "tcp://127.0.0.1:9999"
|
||||
managementURL = "https://mgmt.example.com"
|
||||
configPath = "/tmp/test-config.json"
|
||||
logFiles = []string{"/tmp/test.log"}
|
||||
profilesDisabled = true
|
||||
updateSettingsDisabled = true
|
||||
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
|
||||
|
||||
params := currentServiceParams()
|
||||
|
||||
assert.Equal(t, "trace", params.LogLevel)
|
||||
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
|
||||
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
|
||||
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
|
||||
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
|
||||
assert.True(t, params.DisableProfiles)
|
||||
assert.True(t, params.DisableUpdateSettings)
|
||||
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
origDaemonAddr := daemonAddr
|
||||
origManagementURL := managementURL
|
||||
origConfigPath := configPath
|
||||
origLogFiles := logFiles
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() {
|
||||
logLevel = origLogLevel
|
||||
daemonAddr = origDaemonAddr
|
||||
managementURL = origManagementURL
|
||||
configPath = origConfigPath
|
||||
logFiles = origLogFiles
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
serviceEnvVars = origServiceEnvVars
|
||||
})
|
||||
|
||||
// Reset all flags to defaults.
|
||||
logLevel = "info"
|
||||
daemonAddr = "unix:///var/run/netbird.sock"
|
||||
managementURL = ""
|
||||
configPath = "/etc/netbird/config.json"
|
||||
logFiles = []string{"/var/log/netbird/client.log"}
|
||||
profilesDisabled = false
|
||||
updateSettingsDisabled = false
|
||||
serviceEnvVars = nil
|
||||
|
||||
// Reset Changed state on all relevant flags.
|
||||
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
// Simulate user explicitly setting --log-level via CLI.
|
||||
logLevel = "warn"
|
||||
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
|
||||
|
||||
saved := &serviceParams{
|
||||
LogLevel: "debug",
|
||||
DaemonAddr: "tcp://127.0.0.1:5555",
|
||||
ManagementURL: "https://saved.example.com",
|
||||
ConfigPath: "/saved/config.json",
|
||||
LogFiles: []string{"/saved/client.log"},
|
||||
DisableProfiles: true,
|
||||
DisableUpdateSettings: true,
|
||||
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
// log-level was Changed, so it should keep "warn", not use saved "debug".
|
||||
assert.Equal(t, "warn", logLevel)
|
||||
|
||||
// All other fields were not Changed, so they should use saved values.
|
||||
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
|
||||
assert.Equal(t, "https://saved.example.com", managementURL)
|
||||
assert.Equal(t, "/saved/config.json", configPath)
|
||||
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
|
||||
assert.True(t, profilesDisabled)
|
||||
assert.True(t, updateSettingsDisabled)
|
||||
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
t.Cleanup(func() {
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
})
|
||||
|
||||
// Simulate current state where booleans are true (e.g. set by previous install).
|
||||
profilesDisabled = true
|
||||
updateSettingsDisabled = true
|
||||
|
||||
// Reset Changed state so flags appear unset.
|
||||
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
// Saved params have both as false.
|
||||
saved := &serviceParams{
|
||||
DisableProfiles: false,
|
||||
DisableUpdateSettings: false,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
assert.False(t, profilesDisabled, "saved false should override current true")
|
||||
assert.False(t, updateSettingsDisabled, "saved false should override current true")
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
|
||||
origManagementURL := managementURL
|
||||
t.Cleanup(func() { managementURL = origManagementURL })
|
||||
|
||||
managementURL = "https://leftover.example.com"
|
||||
|
||||
// Simulate saved params where management URL was explicitly cleared.
|
||||
saved := &serviceParams{
|
||||
LogLevel: "info",
|
||||
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||
// ManagementURL intentionally empty: was cleared with --management-url "".
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_NilParams(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
t.Cleanup(func() { logLevel = origLogLevel })
|
||||
|
||||
logLevel = "info"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
|
||||
// Should be a no-op.
|
||||
applyServiceParams(cmd, nil)
|
||||
assert.Equal(t, "info", logLevel)
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
// Set up a command with --service-env marked as Changed.
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
|
||||
|
||||
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{
|
||||
"SAVED": "val",
|
||||
"OVERLAP": "saved",
|
||||
},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
// Parse result for easier assertion.
|
||||
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "yes", result["EXPLICIT"])
|
||||
assert.Equal(t, "val", result["SAVED"])
|
||||
// Explicit wins on conflict.
|
||||
assert.Equal(t, "explicit", result["OVERLAP"])
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
serviceEnvVars = nil
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
// Simulate --service-env "" which produces [""] in the slice.
|
||||
serviceEnvVars = []string{""}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
require.NoError(t, cmd.Flags().Set("service-env", ""))
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
|
||||
}
|
||||
|
||||
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
// Simulate --service-env "" which produces [""] in the slice.
|
||||
serviceEnvVars = []string{""}
|
||||
|
||||
params := currentServiceParams()
|
||||
|
||||
// After parsing, the empty string is skipped, resulting in an empty map.
|
||||
// The map should still be set (not nil) so it overwrites saved values.
|
||||
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
|
||||
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
|
||||
}
|
||||
|
||||
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||
// added to serviceParams but not wired into these functions, this test fails.
|
||||
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Collect all JSON field names from the serviceParams struct.
|
||||
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
|
||||
|
||||
// Collect field names referenced in currentServiceParams and applyServiceParams.
|
||||
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
|
||||
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
|
||||
// applyServiceEnvParams handles ServiceEnvVars indirectly.
|
||||
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
|
||||
for k, v := range applyEnvFields {
|
||||
applyFields[k] = v
|
||||
}
|
||||
|
||||
for _, field := range structFields {
|
||||
assert.Contains(t, currentFields, field,
|
||||
"serviceParams field %q is not captured in currentServiceParams()", field)
|
||||
assert.Contains(t, applyFields, field,
|
||||
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
|
||||
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
|
||||
// it flows through newSVCConfig() EnvVars, not CLI args.
|
||||
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||
require.NotEmpty(t, structFields)
|
||||
|
||||
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
|
||||
fieldsNotInArgs := map[string]bool{
|
||||
"ServiceEnvVars": true,
|
||||
}
|
||||
|
||||
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
|
||||
|
||||
// Forward: every struct field must appear in buildServiceArguments.
|
||||
for _, field := range structFields {
|
||||
if fieldsNotInArgs[field] {
|
||||
continue
|
||||
}
|
||||
globalVar := fieldToGlobalVar(field)
|
||||
assert.Contains(t, buildFields, globalVar,
|
||||
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
|
||||
}
|
||||
|
||||
// Reverse: every service-related global used in buildServiceArguments must
|
||||
// have a corresponding serviceParams field. This catches a developer adding
|
||||
// a new flag to buildServiceArguments without adding it to the struct.
|
||||
globalToField := make(map[string]string, len(structFields))
|
||||
for _, field := range structFields {
|
||||
globalToField[fieldToGlobalVar(field)] = field
|
||||
}
|
||||
// Identifiers in buildServiceArguments that are not service params
|
||||
// (builtins, boilerplate, loop variables).
|
||||
nonParamGlobals := map[string]bool{
|
||||
"args": true, "append": true, "string": true, "_": true,
|
||||
"logFile": true, // range variable over logFiles
|
||||
}
|
||||
for ref := range buildFields {
|
||||
if nonParamGlobals[ref] {
|
||||
continue
|
||||
}
|
||||
_, inStruct := globalToField[ref]
|
||||
assert.True(t, inStruct,
|
||||
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
|
||||
}
|
||||
}
|
||||
|
||||
// extractStructJSONFields returns field names from a named struct type.
|
||||
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
|
||||
t.Helper()
|
||||
var fields []string
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
ts, ok := n.(*ast.TypeSpec)
|
||||
if !ok || ts.Name.Name != structName {
|
||||
return true
|
||||
}
|
||||
st, ok := ts.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, f := range st.Fields.List {
|
||||
if len(f.Names) > 0 {
|
||||
fields = append(fields, f.Names[0].Name)
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return fields
|
||||
}
|
||||
|
||||
// extractFuncFieldRefs returns which of the given field names appear inside the
|
||||
// named function, either as selector expressions (params.FieldName) or as
|
||||
// composite literal keys (&serviceParams{FieldName: ...}).
|
||||
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
|
||||
t.Helper()
|
||||
fieldSet := make(map[string]bool, len(fields))
|
||||
for _, f := range fields {
|
||||
fieldSet[f] = true
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
fn := findFuncDecl(file, funcName)
|
||||
require.NotNil(t, fn, "function %s not found", funcName)
|
||||
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
switch v := n.(type) {
|
||||
case *ast.SelectorExpr:
|
||||
if fieldSet[v.Sel.Name] {
|
||||
found[v.Sel.Name] = true
|
||||
}
|
||||
case *ast.KeyValueExpr:
|
||||
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
|
||||
found[ident.Name] = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
|
||||
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
|
||||
t.Helper()
|
||||
fn := findFuncDecl(file, funcName)
|
||||
require.NotNil(t, fn, "function %s not found", funcName)
|
||||
|
||||
refs := make(map[string]bool)
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
if ident, ok := n.(*ast.Ident); ok {
|
||||
refs[ident.Name] = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
return refs
|
||||
}
|
||||
|
||||
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
|
||||
for _, decl := range file.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
if ok && fn.Name.Name == name {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fieldToGlobalVar maps serviceParams field names to the package-level variable
|
||||
// names used in buildServiceArguments and applyServiceParams.
|
||||
func fieldToGlobalVar(field string) string {
|
||||
m := map[string]string{
|
||||
"LogLevel": "logLevel",
|
||||
"DaemonAddr": "daemonAddr",
|
||||
"ManagementURL": "managementURL",
|
||||
"ConfigPath": "configPath",
|
||||
"LogFiles": "logFiles",
|
||||
"DisableProfiles": "profilesDisabled",
|
||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||
"EnableCapture": "captureEnabled",
|
||||
"DisableNetworks": "networksDisabled",
|
||||
"ServiceEnvVars": "serviceEnvVars",
|
||||
}
|
||||
if v, ok := m[field]; ok {
|
||||
return v
|
||||
}
|
||||
// Default: lowercase first letter.
|
||||
return strings.ToLower(field[:1]) + field[1:]
|
||||
}
|
||||
|
||||
func TestEnvMapToSlice(t *testing.T) {
|
||||
m := map[string]string{"A": "1", "B": "2"}
|
||||
s := envMapToSlice(m)
|
||||
assert.Len(t, s, 2)
|
||||
assert.Contains(t, s, "A=1")
|
||||
assert.Contains(t, s, "B=2")
|
||||
}
|
||||
|
||||
func TestEnvMapToSlice_Empty(t *testing.T) {
|
||||
s := envMapToSlice(map[string]string{})
|
||||
assert.Empty(t, s)
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +15,22 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
||||
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
||||
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
||||
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
||||
// hitting the rate limit and exiting. This makes service restart unreliable.
|
||||
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
||||
<-sig
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
if err := validateDestinationPort(remoteAddr); err != nil {
|
||||
return fmt.Errorf("invalid remote address: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
@@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
if err := validateDestinationPort(localAddr); err != nil {
|
||||
return fmt.Errorf("invalid local address: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
@@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDestinationPort checks that the destination address has a valid port.
|
||||
// Port 0 is only valid for bind addresses (where the OS picks an available port),
|
||||
// not for destination addresses where we need to connect.
|
||||
func validateDestinationPort(addr string) error {
|
||||
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
return fmt.Errorf("port 0 is not valid for destination address")
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return fmt.Errorf("port %d out of range (1-65535)", port)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
|
||||
@@ -28,6 +28,7 @@ var (
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
connectionTypeFilter string
|
||||
checkFlag string
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -49,6 +50,7 @@ func init() {
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -56,6 +58,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
if checkFlag != "" {
|
||||
return runHealthCheck(cmd)
|
||||
}
|
||||
|
||||
err := parseFilters()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -68,15 +74,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
resp, err := getStatus(ctx, false)
|
||||
resp, err := getStatus(ctx, true, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status := resp.GetStatus()
|
||||
|
||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||
status == string(internal.StatusSessionExpired) {
|
||||
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||
status == string(internal.StatusSessionExpired)
|
||||
|
||||
if needsAuth && !jsonFlag && !yamlFlag {
|
||||
cmd.Printf("Daemon status: %s\n\n"+
|
||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||
" netbird up \n\n"+
|
||||
@@ -99,17 +107,27 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
DaemonStatus: nbstatus.ParseDaemonStatus(status),
|
||||
StatusFilter: statusFilter,
|
||||
PrefixNamesFilter: prefixNamesFilter,
|
||||
PrefixNamesFilterMap: prefixNamesFilterMap,
|
||||
IPsFilter: ipsFilterMap,
|
||||
ConnectionTypeFilter: connectionTypeFilter,
|
||||
ProfileName: profName,
|
||||
})
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
statusOutputString = outputInformationHolder.FullDetailSummary()
|
||||
case jsonFlag:
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.JSON()
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.YAML()
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -121,16 +139,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -184,6 +203,83 @@ func enableDetailFlagWhenFilterFlag() {
|
||||
}
|
||||
}
|
||||
|
||||
func runHealthCheck(cmd *cobra.Command) error {
|
||||
check := strings.ToLower(checkFlag)
|
||||
switch check {
|
||||
case "live", "ready", "startup":
|
||||
default:
|
||||
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
|
||||
}
|
||||
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
isStartup := check == "startup"
|
||||
resp, err := getStatus(ctx, isStartup, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch check {
|
||||
case "live":
|
||||
return nil
|
||||
case "ready":
|
||||
return checkReadiness(resp)
|
||||
case "startup":
|
||||
return checkStartup(resp)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkReadiness(resp *proto.StatusResponse) error {
|
||||
daemonStatus := internal.StatusType(resp.GetStatus())
|
||||
switch daemonStatus {
|
||||
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
|
||||
return nil
|
||||
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
|
||||
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
|
||||
default:
|
||||
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func checkStartup(resp *proto.StatusResponse) error {
|
||||
fullStatus := resp.GetFullStatus()
|
||||
if fullStatus == nil {
|
||||
return fmt.Errorf("startup check: no full status available")
|
||||
}
|
||||
|
||||
if !fullStatus.GetManagementState().GetConnected() {
|
||||
return fmt.Errorf("startup check: management not connected")
|
||||
}
|
||||
|
||||
if !fullStatus.GetSignalState().GetConnected() {
|
||||
return fmt.Errorf("startup check: signal not connected")
|
||||
}
|
||||
|
||||
var relayCount, relaysConnected int
|
||||
for _, r := range fullStatus.GetRelays() {
|
||||
uri := r.GetURI()
|
||||
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
|
||||
continue
|
||||
}
|
||||
relayCount++
|
||||
if r.GetAvailable() {
|
||||
relaysConnected++
|
||||
}
|
||||
}
|
||||
|
||||
if relayCount > 0 && relaysConnected == 0 {
|
||||
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseInterfaceIP(interfaceIP string) string {
|
||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,11 +13,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
@@ -89,9 +92,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
@@ -100,9 +100,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
@@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -127,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -152,7 +160,7 @@ func startClientDaemon(
|
||||
s := grpc.NewServer()
|
||||
|
||||
server := client.New(ctx,
|
||||
"", "", false, false)
|
||||
"", "", false, false, false, false)
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,9 @@ const (
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
|
||||
showQRFlag = "qr"
|
||||
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
|
||||
|
||||
profileNameFlag = "profile"
|
||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||
)
|
||||
@@ -48,6 +51,7 @@ var (
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
showQR bool
|
||||
profileName string
|
||||
configPath string
|
||||
|
||||
@@ -80,6 +84,7 @@ func init() {
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||
|
||||
@@ -197,10 +202,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
@@ -216,6 +221,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user