mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-08 01:39:55 +00:00
Compare commits
15 Commits
v0.70.4
...
fix/login-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de527ad3a8 | ||
|
|
104990dfdd | ||
|
|
bde632c3b2 | ||
|
|
4268a5cfb7 | ||
|
|
a547fc74ed | ||
|
|
a21f6ecb0a | ||
|
|
6262b0d841 | ||
|
|
50b58a6828 | ||
|
|
057d651d2e | ||
|
|
c4b2da4c92 | ||
|
|
dcd1db42ef | ||
|
|
f29f5a0978 | ||
|
|
fda2146b48 | ||
|
|
ce79108d55 | ||
|
|
25a45d8ceb |
@@ -17,6 +17,7 @@ ENV \
|
||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENABLE_CAPTURE="false" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
@@ -23,6 +23,7 @@ ENV \
|
||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||
NB_DISABLE_DNS="true" \
|
||||
NB_ENABLE_CAPTURE="false" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
196
client/cmd/capture.go
Normal file
196
client/cmd/capture.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
var captureCmd = &cobra.Command{
|
||||
Use: "capture",
|
||||
Short: "Capture packets on the WireGuard interface",
|
||||
Long: `Captures decrypted packets flowing through the WireGuard interface.
|
||||
|
||||
Default output is human-readable text. Use --pcap or --output for pcap binary.
|
||||
Requires --enable-capture to be set at service install or reconfigure time.
|
||||
|
||||
Examples:
|
||||
netbird debug capture
|
||||
netbird debug capture host 100.64.0.1 and port 443
|
||||
netbird debug capture tcp
|
||||
netbird debug capture icmp
|
||||
netbird debug capture src host 10.0.0.1 and dst port 80
|
||||
netbird debug capture -o capture.pcap
|
||||
netbird debug capture --pcap | tshark -r -
|
||||
netbird debug capture --pcap | tcpdump -r - -n`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
RunE: runCapture,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.AddCommand(captureCmd)
|
||||
|
||||
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
|
||||
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
|
||||
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
|
||||
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
|
||||
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
|
||||
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
|
||||
}
|
||||
|
||||
func runCapture(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
cmd.PrintErrf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
req, err := buildCaptureRequest(cmd, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.StartCapture(ctx, req)
|
||||
if err != nil {
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
|
||||
// First Recv is the empty acceptance message from the server. If the
|
||||
// device is unavailable (kernel WG, not connected, capture disabled),
|
||||
// the server returns an error instead.
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
|
||||
out, cleanup, err := captureOutput(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if req.TextOutput {
|
||||
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
|
||||
} else {
|
||||
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
|
||||
}
|
||||
|
||||
streamErr := streamCapture(ctx, cmd, stream, out)
|
||||
cleanupErr := cleanup()
|
||||
if streamErr != nil {
|
||||
return streamErr
|
||||
}
|
||||
return cleanupErr
|
||||
}
|
||||
|
||||
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
|
||||
req := &proto.StartCaptureRequest{}
|
||||
|
||||
if len(args) > 0 {
|
||||
expr := strings.Join(args, " ")
|
||||
if _, err := capture.ParseFilter(expr); err != nil {
|
||||
return nil, fmt.Errorf("invalid filter: %w", err)
|
||||
}
|
||||
req.FilterExpr = expr
|
||||
}
|
||||
|
||||
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
|
||||
req.SnapLen = snap
|
||||
}
|
||||
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
|
||||
if d < 0 {
|
||||
return nil, fmt.Errorf("duration must not be negative")
|
||||
}
|
||||
req.Duration = durationpb.New(d)
|
||||
}
|
||||
req.Verbose, _ = cmd.Flags().GetBool("verbose")
|
||||
req.Ascii, _ = cmd.Flags().GetBool("ascii")
|
||||
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
req.TextOutput = !forcePcap && outPath == ""
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
|
||||
for {
|
||||
pkt, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.PrintErrf("\nCapture stopped.\n")
|
||||
return nil //nolint:nilerr // user interrupted
|
||||
}
|
||||
if err == io.EOF {
|
||||
cmd.PrintErrf("\nCapture finished.\n")
|
||||
return nil
|
||||
}
|
||||
return handleCaptureError(err)
|
||||
}
|
||||
if _, err := out.Write(pkt.GetData()); err != nil {
|
||||
return fmt.Errorf("write output: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// captureOutput returns the writer for capture data and a cleanup function
|
||||
// that finalizes the file. Errors from the cleanup must be propagated.
|
||||
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
if outPath == "" {
|
||||
return os.Stdout, func() error { return nil }, nil
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create output file: %w", err)
|
||||
}
|
||||
tmpPath := f.Name()
|
||||
return f, func() error {
|
||||
var merr *multierror.Error
|
||||
if err := f.Close(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
|
||||
}
|
||||
fi, statErr := os.Stat(tmpPath)
|
||||
if statErr != nil || fi.Size() == 0 {
|
||||
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
if err := os.Rename(tmpPath, outPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
cmd.PrintErrf("Wrote %s\n", outPath)
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func handleCaptureError(err error) error {
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return fmt.Errorf("%s", s.Message())
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
@@ -239,11 +240,50 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}()
|
||||
}
|
||||
|
||||
captureStarted := false
|
||||
if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
|
||||
captureTimeout := duration + 30*time.Second
|
||||
const maxBundleCapture = 10 * time.Minute
|
||||
if captureTimeout > maxBundleCapture {
|
||||
captureTimeout = maxBundleCapture
|
||||
}
|
||||
_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
|
||||
Timeout: durationpb.New(captureTimeout),
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
captureStarted = true
|
||||
cmd.Println("Packet capture started.")
|
||||
// Safety: always stop on exit, even if the normal stop below runs too.
|
||||
defer func() {
|
||||
if captureStarted {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
if captureStarted {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
||||
} else {
|
||||
captureStarted = false
|
||||
cmd.Println("Packet capture stopped.")
|
||||
}
|
||||
}
|
||||
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
@@ -416,4 +456,5 @@ func init() {
|
||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
}
|
||||
@@ -115,6 +117,24 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
// set the new config
|
||||
cfg, err := client.GetConfig(ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
Username: username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config from daemon: %v", err)
|
||||
}
|
||||
|
||||
req := setupSetConfigReqForLogin(cfg, activeProf.Name, username)
|
||||
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
|
||||
log.Warnf("setConfig method is not available in the daemon")
|
||||
} else {
|
||||
return fmt.Errorf("call service setConfig method: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
@@ -256,7 +276,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
}
|
||||
|
||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
|
||||
|
||||
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -324,7 +344,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
if err != nil {
|
||||
@@ -334,7 +354,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
@@ -348,6 +368,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
if showQR {
|
||||
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
|
||||
printQRCode(f, verificationURIComplete)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
@@ -378,3 +404,34 @@ func setEnvAndFlags(cmd *cobra.Command) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupSetConfigReqForLogin(cfg *proto.GetConfigResponse, profileName, username string) *proto.SetConfigRequest {
|
||||
var req proto.SetConfigRequest
|
||||
req.ProfileName = profileName
|
||||
req.Username = username
|
||||
|
||||
req.ManagementUrl = managementURL
|
||||
req.AdminURL = adminURL
|
||||
|
||||
req.RosenpassEnabled = &cfg.RosenpassEnabled
|
||||
req.RosenpassPermissive = &cfg.RosenpassPermissive
|
||||
req.DisableAutoConnect = &cfg.DisableAutoConnect
|
||||
req.ServerSSHAllowed = &cfg.ServerSSHAllowed
|
||||
req.NetworkMonitor = &cfg.NetworkMonitor
|
||||
req.DisableClientRoutes = &cfg.DisableClientRoutes
|
||||
req.DisableServerRoutes = &cfg.DisableServerRoutes
|
||||
req.DisableDns = &cfg.DisableDns
|
||||
req.DisableFirewall = &cfg.DisableFirewall
|
||||
req.BlockLanAccess = &cfg.BlockLanAccess
|
||||
req.DisableNotifications = &cfg.DisableNotifications
|
||||
req.LazyConnectionEnabled = &cfg.LazyConnectionEnabled
|
||||
req.BlockInbound = &cfg.BlockInbound
|
||||
req.DisableSSHAuth = &cfg.DisableSSHAuth
|
||||
req.EnableSSHRoot = &cfg.EnableSSHRoot
|
||||
req.EnableSSHSFTP = &cfg.EnableSSHSFTP
|
||||
req.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
|
||||
req.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
|
||||
req.SshJWTCacheTTL = &cfg.SshJWTCacheTTL
|
||||
|
||||
return &req
|
||||
}
|
||||
|
||||
25
client/cmd/qr.go
Normal file
25
client/cmd/qr.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
)
|
||||
|
||||
// printQRCode prints a QR code for the given URL to the writer.
|
||||
// Called only when the user explicitly requests QR output via --qr.
|
||||
func printQRCode(w io.Writer, url string) {
|
||||
if url == "" {
|
||||
return
|
||||
}
|
||||
qrterminal.GenerateWithConfig(url, qrterminal.Config{
|
||||
Level: qrterminal.M,
|
||||
Writer: w,
|
||||
HalfBlocks: true,
|
||||
BlackChar: qrterminal.BLACK_BLACK,
|
||||
WhiteChar: qrterminal.WHITE_WHITE,
|
||||
BlackWhiteChar: qrterminal.BLACK_WHITE,
|
||||
WhiteBlackChar: qrterminal.WHITE_BLACK,
|
||||
QuietZone: qrterminal.QUIET_ZONE,
|
||||
})
|
||||
}
|
||||
26
client/cmd/qr_test.go
Normal file
26
client/cmd/qr_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrintQRCode_EmptyURL(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "")
|
||||
|
||||
if buf.Len() != 0 {
|
||||
t.Error("expected no output for empty URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintQRCode_WritesOutput(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "https://example.com/auth")
|
||||
|
||||
if buf.Len() == 0 {
|
||||
t.Error("expected QR code output for non-empty URL")
|
||||
}
|
||||
}
|
||||
@@ -75,6 +75,7 @@ var (
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
|
||||
@@ -44,6 +44,7 @@ func init() {
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
|
||||
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
|
||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
}
|
||||
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
|
||||
if err := serverInstance.Start(); err != nil {
|
||||
log.Fatalf("failed to start daemon: %v", err)
|
||||
}
|
||||
|
||||
@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-update-settings")
|
||||
}
|
||||
|
||||
if captureEnabled {
|
||||
args = append(args, "--enable-capture")
|
||||
}
|
||||
|
||||
if networksDisabled {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type serviceParams struct {
|
||||
LogFiles []string `json:"log_files,omitempty"`
|
||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
@@ -79,6 +80,7 @@ func currentServiceParams() *serviceParams {
|
||||
LogFiles: logFiles,
|
||||
DisableProfiles: profilesDisabled,
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
EnableCapture: captureEnabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
}
|
||||
|
||||
@@ -144,6 +146,10 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
updateSettingsDisabled = params.DisableUpdateSettings
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("enable-capture") {
|
||||
captureEnabled = params.EnableCapture
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
@@ -535,6 +535,7 @@ func fieldToGlobalVar(field string) string {
|
||||
"LogFiles": "logFiles",
|
||||
"DisableProfiles": "profilesDisabled",
|
||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||
"EnableCapture": "captureEnabled",
|
||||
"DisableNetworks": "networksDisabled",
|
||||
"ServiceEnvVars": "serviceEnvVars",
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ func startClientDaemon(
|
||||
s := grpc.NewServer()
|
||||
|
||||
server := client.New(ctx,
|
||||
"", "", false, false, false)
|
||||
"", "", false, false, false, false)
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,9 @@ const (
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
|
||||
showQRFlag = "qr"
|
||||
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
|
||||
|
||||
profileNameFlag = "profile"
|
||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||
)
|
||||
@@ -48,6 +51,7 @@ var (
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
showQR bool
|
||||
profileName string
|
||||
configPath string
|
||||
|
||||
@@ -80,6 +84,7 @@ func init() {
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||
|
||||
|
||||
65
client/embed/capture.go
Normal file
65
client/embed/capture.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
// CaptureOptions configures a packet capture session.
|
||||
type CaptureOptions struct {
|
||||
// Output receives pcap-formatted data. Nil disables pcap output.
|
||||
Output io.Writer
|
||||
// TextOutput receives human-readable packet summaries. Nil disables text output.
|
||||
TextOutput io.Writer
|
||||
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
|
||||
// Empty captures all packets.
|
||||
Filter string
|
||||
// Verbose adds seq/ack, TTL, window, and total length to text output.
|
||||
Verbose bool
|
||||
// ASCII dumps transport payload as printable ASCII after each packet line.
|
||||
ASCII bool
|
||||
}
|
||||
|
||||
// CaptureStats reports capture session counters.
|
||||
type CaptureStats struct {
|
||||
Packets int64
|
||||
Bytes int64
|
||||
Dropped int64
|
||||
}
|
||||
|
||||
// CaptureSession represents an active packet capture. Call Stop to end the
|
||||
// capture and flush buffered packets.
|
||||
type CaptureSession struct {
|
||||
sess *capture.Session
|
||||
engine *internal.Engine
|
||||
}
|
||||
|
||||
// Stop ends the capture, flushes remaining packets, and detaches from the device.
|
||||
// Safe to call multiple times.
|
||||
func (cs *CaptureSession) Stop() {
|
||||
if cs.engine != nil {
|
||||
_ = cs.engine.SetCapture(nil)
|
||||
cs.engine = nil
|
||||
}
|
||||
if cs.sess != nil {
|
||||
cs.sess.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns current capture counters.
|
||||
func (cs *CaptureSession) Stats() CaptureStats {
|
||||
s := cs.sess.Stats()
|
||||
return CaptureStats{
|
||||
Packets: s.Packets,
|
||||
Bytes: s.Bytes,
|
||||
Dropped: s.Dropped,
|
||||
}
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the capture's writer goroutine
|
||||
// has fully exited and all buffered packets have been flushed.
|
||||
func (cs *CaptureSession) Done() <-chan struct{} {
|
||||
return cs.sess.Done()
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -65,7 +66,7 @@ type Options struct {
|
||||
PrivateKey string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
// PreSharedKey is the pre-shared key for the tunnel interface
|
||||
PreSharedKey string
|
||||
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||
LogOutput io.Writer
|
||||
@@ -81,9 +82,9 @@ type Options struct {
|
||||
DisableClientRoutes bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the WireGuard interface.
|
||||
// MTU is the MTU for the tunnel interface.
|
||||
// Valid values are in the range 576..8192 bytes.
|
||||
// If non-nil, this value overrides any value stored in the config file.
|
||||
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
||||
@@ -469,6 +470,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var matcher capture.Matcher
|
||||
if opts.Filter != "" {
|
||||
m, err := capture.ParseFilter(opts.Filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse filter: %w", err)
|
||||
}
|
||||
matcher = m
|
||||
}
|
||||
|
||||
sess, err := capture.NewSession(capture.Options{
|
||||
Output: opts.Output,
|
||||
TextOutput: opts.TextOutput,
|
||||
Matcher: matcher,
|
||||
Verbose: opts.Verbose,
|
||||
ASCII: opts.ASCII,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create capture session: %w", err)
|
||||
}
|
||||
|
||||
if err := engine.SetCapture(sess); err != nil {
|
||||
sess.Stop()
|
||||
return nil, fmt.Errorf("set capture: %w", err)
|
||||
}
|
||||
|
||||
return &CaptureSession{sess: sess, engine: engine}, nil
|
||||
}
|
||||
|
||||
// StopCapture stops the active capture session if one is running.
|
||||
func (c *Client) StopCapture() error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetCapture(nil)
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
|
||||
@@ -115,12 +115,13 @@ type Manager struct {
|
||||
|
||||
localipmanager *localIPManager
|
||||
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
pendingCapture atomic.Pointer[forwarder.PacketCapture]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRule firewall.Rule
|
||||
|
||||
@@ -351,6 +352,19 @@ func (m *Manager) determineRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
|
||||
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
|
||||
func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
|
||||
if pc == nil {
|
||||
m.pendingCapture.Store(nil)
|
||||
} else {
|
||||
m.pendingCapture.Store(&pc)
|
||||
}
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.SetCapture(pc)
|
||||
}
|
||||
}
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder.Load() != nil {
|
||||
@@ -372,6 +386,11 @@ func (m *Manager) initForwarder() error {
|
||||
|
||||
m.forwarder.Store(forwarder)
|
||||
|
||||
// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
|
||||
if pc := m.pendingCapture.Load(); pc != nil {
|
||||
forwarder.SetCapture(*pc)
|
||||
}
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
return nil
|
||||
@@ -614,6 +633,7 @@ func (m *Manager) resetState() {
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.SetCapture(nil)
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
|
||||
@@ -12,12 +12,19 @@ import (
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
// PacketCapture captures raw packets for debugging. Implementations must be
|
||||
// safe for concurrent use and must not block.
|
||||
type PacketCapture interface {
|
||||
Offer(data []byte, outbound bool)
|
||||
}
|
||||
|
||||
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||
type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu atomic.Uint32
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -54,13 +61,17 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
pktBytes := data.AsSlice()
|
||||
|
||||
address := netHeader.DestinationAddress()
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
|
||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if pc := e.capture.Load(); pc != nil {
|
||||
(*pc).Offer(pktBytes, true)
|
||||
}
|
||||
written++
|
||||
}
|
||||
|
||||
|
||||
@@ -139,6 +139,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// SetCapture sets or clears the packet capture on the forwarder endpoint.
|
||||
// This captures outbound packets that bypass the FilteredDevice (netstack forwarding).
|
||||
func (f *Forwarder) SetCapture(pc PacketCapture) {
|
||||
if pc == nil {
|
||||
f.endpoint.capture.Store(nil)
|
||||
return
|
||||
}
|
||||
f.endpoint.capture.Store(&pc)
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
|
||||
@@ -270,5 +270,9 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []
|
||||
return 0
|
||||
}
|
||||
|
||||
if pc := f.endpoint.capture.Load(); pc != nil {
|
||||
(*pc).Offer(fullPacket, true)
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package device
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
@@ -28,11 +29,20 @@ type PacketFilter interface {
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
}
|
||||
|
||||
// PacketCapture captures raw packets for debugging. Implementations must be
|
||||
// safe for concurrent use and must not block.
|
||||
type PacketCapture interface {
|
||||
// Offer submits a packet for capture. outbound is true for packets
|
||||
// leaving the host (Read path), false for packets arriving (Write path).
|
||||
Offer(data []byte, outbound bool)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
@@ -63,20 +73,25 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
filter := d.filter
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if filter == nil {
|
||||
return
|
||||
if filter != nil {
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
i--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
i--
|
||||
if pc := d.capture.Load(); pc != nil {
|
||||
for i := 0; i < n; i++ {
|
||||
(*pc).Offer(bufs[i][offset:offset+sizes[i]], true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +100,13 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
|
||||
// Write wraps write method with filtering feature
|
||||
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
// Capture before filtering so dropped packets are still visible in captures.
|
||||
if pc := d.capture.Load(); pc != nil {
|
||||
for _, buf := range bufs {
|
||||
(*pc).Offer(buf[offset:], false)
|
||||
}
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
filter := d.filter
|
||||
d.mutex.RUnlock()
|
||||
@@ -96,9 +118,10 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
if filter.FilterInbound(buf[offset:], len(buf)) {
|
||||
dropped++
|
||||
} else {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,3 +136,14 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
|
||||
d.filter = filter
|
||||
d.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
|
||||
// Uses atomic store so the hot path (Read/Write) is a single pointer load
|
||||
// with no locking overhead when capture is off.
|
||||
func (d *FilteredDevice) SetCapture(pc PacketCapture) {
|
||||
if pc == nil {
|
||||
d.capture.Store(nil)
|
||||
return
|
||||
}
|
||||
d.capture.Store(&pc)
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if n != 0 {
|
||||
if n != 1 {
|
||||
t.Errorf("expected n=1, got %d", n)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,6 +61,7 @@ allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
cpu.prof: CPU profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
@@ -234,6 +235,7 @@ type BundleGenerator struct {
|
||||
logPath string
|
||||
tempDir string
|
||||
cpuProfile []byte
|
||||
capturePath string
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
|
||||
@@ -257,7 +259,8 @@ type GeneratorDependencies struct {
|
||||
LogPath string
|
||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
CapturePath string
|
||||
RefreshStatus func()
|
||||
ClientMetrics MetricsExporter
|
||||
}
|
||||
|
||||
@@ -277,6 +280,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
logPath: deps.LogPath,
|
||||
tempDir: deps.TempDir,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
capturePath: deps.CapturePath,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
|
||||
@@ -346,6 +350,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addCaptureFile(); err != nil {
|
||||
log.Errorf("failed to add capture file to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
@@ -669,6 +677,29 @@ func (g *BundleGenerator) addCPUProfile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCaptureFile() error {
|
||||
if g.capturePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(g.capturePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open capture file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := g.addFileToZip(f, "capture.pcap"); err != nil {
|
||||
return fmt.Errorf("add capture file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
@@ -68,6 +69,7 @@ import (
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -218,6 +220,8 @@ type Engine struct {
|
||||
portForwardManager *portforward.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
afpacketCapture *capture.AFPacketCapture
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
@@ -1703,6 +1707,11 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
}
|
||||
|
||||
func (e *Engine) close() {
|
||||
if e.afpacketCapture != nil {
|
||||
e.afpacketCapture.Stop()
|
||||
e.afpacketCapture = nil
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
@@ -2168,6 +2177,62 @@ func (e *Engine) Address() (netip.Addr, error) {
|
||||
return e.wgInterface.Address().IP, nil
|
||||
}
|
||||
|
||||
// SetCapture sets or clears packet capture on the WireGuard device.
|
||||
// On userspace WireGuard, it taps the FilteredDevice directly.
|
||||
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
|
||||
// Pass nil to disable capture.
|
||||
func (e *Engine) SetCapture(pc device.PacketCapture) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
intf := e.wgInterface
|
||||
if intf == nil {
|
||||
return errors.New("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
if e.afpacketCapture != nil {
|
||||
e.afpacketCapture.Stop()
|
||||
e.afpacketCapture = nil
|
||||
}
|
||||
|
||||
dev := intf.GetDevice()
|
||||
if dev != nil {
|
||||
dev.SetCapture(pc)
|
||||
e.setForwarderCapture(pc)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
|
||||
if pc == nil {
|
||||
return nil
|
||||
}
|
||||
sess, ok := pc.(*capture.Session)
|
||||
if !ok {
|
||||
return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
|
||||
}
|
||||
|
||||
afc := capture.NewAFPacketCapture(intf.Name(), sess)
|
||||
if err := afc.Start(); err != nil {
|
||||
return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
|
||||
}
|
||||
e.afpacketCapture = afc
|
||||
return nil
|
||||
}
|
||||
|
||||
// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
|
||||
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
|
||||
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
type forwarderCapturer interface {
|
||||
SetPacketCapture(pc forwarder.PacketCapture)
|
||||
}
|
||||
if fc, ok := e.firewall.(forwarderCapturer); ok {
|
||||
fc.SetPacketCapture(pc)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||
if e.firewall == nil {
|
||||
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||
@@ -2389,6 +2454,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
||||
}
|
||||
}
|
||||
|
||||
relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())
|
||||
|
||||
offerAnswer := peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
@@ -2399,7 +2466,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassAddr: rosenpassAddr,
|
||||
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
|
||||
RelaySrvIP: relayIP,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
return &offerAnswer, nil
|
||||
}
|
||||
|
||||
// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
|
||||
// netip.Addr. Returns the zero value for empty input and logs a warning
|
||||
// for malformed payloads.
|
||||
func decodeRelayIP(b []byte) netip.Addr {
|
||||
if len(b) == 0 {
|
||||
return netip.Addr{}
|
||||
}
|
||||
ip, ok := netip.AddrFromSlice(b)
|
||||
if !ok {
|
||||
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
|
||||
return netip.Addr{}
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||
@@ -91,8 +90,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
|
||||
m.routesMu.Lock()
|
||||
defer m.routesMu.Unlock()
|
||||
|
||||
maps.Clear(m.peerToHAGroups)
|
||||
maps.Clear(m.haGroupToPeers)
|
||||
clear(m.peerToHAGroups)
|
||||
clear(m.haGroupToPeers)
|
||||
|
||||
for haUniqueID, routes := range haMap {
|
||||
var peers []string
|
||||
|
||||
@@ -3,8 +3,6 @@ package store
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
@@ -30,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
|
||||
func (m *Memory) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
maps.Clear(m.events)
|
||||
clear(m.events)
|
||||
}
|
||||
|
||||
func (m *Memory) GetEvents() []*types.Event {
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -40,6 +41,10 @@ type OfferAnswer struct {
|
||||
|
||||
// relay server address
|
||||
RelaySrvAddress string
|
||||
// RelaySrvIP is the IP the remote peer is connected to on its
|
||||
// relay server. Used as a dial target if DNS for RelaySrvAddress
|
||||
// fails. Zero value if the peer did not advertise an IP.
|
||||
RelaySrvIP netip.Addr
|
||||
// SessionID is the unique identifier of the session, used to discard old messages
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
@@ -217,8 +222,9 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
answer.SessionID = &sid
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
answer.RelaySrvAddress = addr
|
||||
answer.RelaySrvIP = ip
|
||||
}
|
||||
|
||||
return answer
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
type mocListener struct {
|
||||
lastState int
|
||||
wg sync.WaitGroup
|
||||
peersWg sync.WaitGroup
|
||||
peers int
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
|
||||
}
|
||||
func (l *mocListener) OnPeersListChanged(size int) {
|
||||
l.peers = size
|
||||
l.peersWg.Done()
|
||||
}
|
||||
|
||||
func (l *mocListener) setWaiter() {
|
||||
@@ -43,6 +45,14 @@ func (l *mocListener) wait() {
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
func (l *mocListener) setPeersWaiter() {
|
||||
l.peersWg.Add(1)
|
||||
}
|
||||
|
||||
func (l *mocListener) waitPeers() {
|
||||
l.peersWg.Wait()
|
||||
}
|
||||
|
||||
func Test_notifier_serverState(t *testing.T) {
|
||||
|
||||
type scenario struct {
|
||||
@@ -72,11 +82,13 @@ func Test_notifier_serverState(t *testing.T) {
|
||||
func Test_notifier_SetListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
listener.setPeersWaiter()
|
||||
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
listener.wait()
|
||||
listener.waitPeers()
|
||||
if listener.lastState != n.lastNotification {
|
||||
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
||||
}
|
||||
@@ -85,9 +97,14 @@ func Test_notifier_SetListener(t *testing.T) {
|
||||
func Test_notifier_RemoveListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
listener.setPeersWaiter()
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
// setListener replays cached state on a goroutine; wait for both the state
|
||||
// and peers callbacks to finish so we don't race on listener.peers.
|
||||
listener.wait()
|
||||
listener.waitPeers()
|
||||
n.removeListener()
|
||||
n.peerListChanged(1)
|
||||
|
||||
|
||||
@@ -54,19 +54,19 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
offerAnswer.WgListenPort,
|
||||
remoteKey,
|
||||
&signal.Credential{
|
||||
msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{
|
||||
Type: bodyType,
|
||||
WgListenPort: offerAnswer.WgListenPort,
|
||||
Credential: &signal.Credential{
|
||||
UFrag: offerAnswer.IceCredentials.UFrag,
|
||||
Pwd: offerAnswer.IceCredentials.Pwd,
|
||||
},
|
||||
bodyType,
|
||||
offerAnswer.RosenpassPubKey,
|
||||
offerAnswer.RosenpassAddr,
|
||||
offerAnswer.RelaySrvAddress,
|
||||
sessionIDBytes)
|
||||
RosenpassPubKey: offerAnswer.RosenpassPubKey,
|
||||
RosenpassAddr: offerAnswer.RosenpassAddr,
|
||||
RelaySrvAddress: offerAnswer.RelaySrvAddress,
|
||||
RelaySrvIP: offerAnswer.RelaySrvIP,
|
||||
SessionID: sessionIDBytes,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
// UpdatePeerState updates peer status
|
||||
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -343,23 +343,29 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
// when we close the connection we will not notify the router manager
|
||||
if receivedState.ConnStatus == StatusIdle {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
notifyRouter := receivedState.ConnStatus == StatusIdle
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -371,17 +377,20 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
||||
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
numPeers := d.numOfPeers()
|
||||
d.mux.Unlock()
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -393,8 +402,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.routeIDLookup.RemoveRemoteRouteID(pref)
|
||||
}
|
||||
|
||||
numPeers := d.numOfPeers()
|
||||
d.mux.Unlock()
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -410,10 +422,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
|
||||
|
||||
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -431,22 +443,28 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -461,22 +479,28 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -490,22 +514,28 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
d.mux.Unlock()
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
@@ -522,12 +552,18 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||
d.notifyPeerListChanged()
|
||||
}
|
||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.mux.Unlock()
|
||||
|
||||
if notifyList {
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
}
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -594,17 +630,33 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if !d.peerListChangedForNotification {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.peerListChangedForNotification = false
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
numPeers := d.numOfPeers()
|
||||
|
||||
// snapshot per-peer router state to deliver after the lock is released
|
||||
type routerDispatch struct {
|
||||
peerID string
|
||||
snapshot map[string]RouterState
|
||||
}
|
||||
dispatches := make([]routerDispatch, 0, len(d.peers))
|
||||
for key := range d.peers {
|
||||
d.notifyPeerStateChangeListeners(key)
|
||||
snapshot := d.snapshotRouterPeersLocked(key, true)
|
||||
if snapshot != nil {
|
||||
dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
|
||||
}
|
||||
}
|
||||
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
for _, rd := range dispatches {
|
||||
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -655,10 +707,12 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
// UpdateLocalPeerState updates local peer status
|
||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = localPeerState
|
||||
d.notifyAddressChanged()
|
||||
fqdn := d.localPeer.FQDN
|
||||
ip := d.localPeer.IP
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
}
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
@@ -721,30 +775,36 @@ func (d *Status) CleanLocalPeerStateRoutes() {
|
||||
// CleanLocalPeerState cleans local peer status
|
||||
func (d *Status) CleanLocalPeerState() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = LocalPeerState{}
|
||||
d.notifyAddressChanged()
|
||||
fqdn := d.localPeer.FQDN
|
||||
ip := d.localPeer.IP
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
}
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = false
|
||||
d.managementError = err
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = true
|
||||
d.managementError = nil
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
@@ -778,21 +838,25 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = false
|
||||
d.signalError = err
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = true
|
||||
d.signalError = nil
|
||||
mgm := d.managementState
|
||||
sig := d.signalState
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
@@ -919,7 +983,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
@@ -1012,18 +1076,17 @@ func (d *Status) RemoveConnectionListener() {
|
||||
d.notifier.removeListener()
|
||||
}
|
||||
|
||||
func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
subs, ok := d.changeNotify[peerID]
|
||||
if !ok {
|
||||
return
|
||||
// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers.
|
||||
// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID
|
||||
// or when notify is false. The snapshot is consumed later by dispatchRouterPeers
|
||||
// outside the lock so the channel send cannot stall any d.mux holder.
|
||||
func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState {
|
||||
if !notify {
|
||||
return nil
|
||||
}
|
||||
if _, ok := d.changeNotify[peerID]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// collect the relevant data for router peers
|
||||
routerPeers := make(map[string]RouterState, len(d.changeNotify))
|
||||
for pid := range d.changeNotify {
|
||||
s, ok := d.peers[pid]
|
||||
@@ -1031,13 +1094,35 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
log.Warnf("router peer not found in peers list: %s", pid)
|
||||
continue
|
||||
}
|
||||
|
||||
routerPeers[pid] = RouterState{
|
||||
Status: s.ConnStatus,
|
||||
Relayed: s.Relayed,
|
||||
Latency: s.Latency,
|
||||
}
|
||||
}
|
||||
return routerPeers
|
||||
}
|
||||
|
||||
// dispatchRouterPeers delivers a previously snapshotted router-state map to
|
||||
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
|
||||
// fresh, short read of d.changeNotify under the lock to grab subscriber
|
||||
// channels, then sends outside the lock so a slow consumer cannot block other
|
||||
// d.mux holders. The send itself stays blocking (only short-circuited by the
|
||||
// subscriber's context) so peer state transitions are not silently dropped.
|
||||
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
|
||||
if routerPeers == nil {
|
||||
return
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
subsMap, ok := d.changeNotify[peerID]
|
||||
subs := make([]*StatusChangeSubscription, 0, len(subsMap))
|
||||
if ok {
|
||||
for _, sub := range subsMap {
|
||||
subs = append(subs, sub)
|
||||
}
|
||||
}
|
||||
d.mux.Unlock()
|
||||
|
||||
for _, sub := range subs {
|
||||
select {
|
||||
@@ -1047,14 +1132,6 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(d.numOfPeers())
|
||||
}
|
||||
|
||||
func (d *Status) notifyAddressChanged() {
|
||||
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||
}
|
||||
|
||||
func (d *Status) numOfPeers() int {
|
||||
return len(d.peers) + len(d.offlinePeers)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.relaySupportedOnRemotePeer.Store(true)
|
||||
|
||||
// the relayManager will return with error in case if the connection has lost with relay server
|
||||
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
|
||||
currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to handle new offer: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||
var serverIP netip.Addr
|
||||
if srv == remoteOfferAnswer.RelaySrvAddress {
|
||||
serverIP = remoteOfferAnswer.RelaySrvIP
|
||||
}
|
||||
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
|
||||
if err != nil {
|
||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||
@@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||
func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) {
|
||||
return w.relayManager.RelayInstanceAddress()
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -44,8 +43,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
if rs.selectedRoutes == nil {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
maps.Clear(rs.deselectedRoutes)
|
||||
maps.Clear(rs.selectedRoutes)
|
||||
clear(rs.deselectedRoutes)
|
||||
clear(rs.selectedRoutes)
|
||||
for _, r := range allRoutes {
|
||||
rs.deselectedRoutes[r] = struct{}{}
|
||||
}
|
||||
@@ -78,8 +77,8 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
||||
if rs.selectedRoutes == nil {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
maps.Clear(rs.deselectedRoutes)
|
||||
maps.Clear(rs.selectedRoutes)
|
||||
clear(rs.deselectedRoutes)
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// DeselectRoutes removes specific routes from the selection.
|
||||
@@ -116,8 +115,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
if rs.selectedRoutes == nil {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
maps.Clear(rs.deselectedRoutes)
|
||||
maps.Clear(rs.selectedRoutes)
|
||||
clear(rs.deselectedRoutes)
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||
|
||||
// Start begins monitoring the WireGuard interface.
|
||||
// It relies on the provided context cancellation to stop.
|
||||
//
|
||||
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
|
||||
// to avoid the allocation churn of repeatedly dumping the kernel link
|
||||
// table; on other platforms it falls back to a low-frequency poll.
|
||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||
defer close(m.done)
|
||||
|
||||
@@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
||||
|
||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return watchInterface(ctx, ifaceName, expectedIndex)
|
||||
}
|
||||
|
||||
// getInterfaceIndex returns the index of a network interface by name.
|
||||
|
||||
134
client/internal/wg_iface_monitor_linux.go
Normal file
134
client/internal/wg_iface_monitor_linux.go
Normal file
@@ -0,0 +1,134 @@
|
||||
//go:build linux
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
|
||||
// deletion or recreation of the WireGuard interface.
|
||||
//
|
||||
// The previous implementation polled net.InterfaceByName every 2 s, which
|
||||
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
|
||||
// entire kernel link table on every call. On hosts with many veth
|
||||
// interfaces (containers, bridges) the resulting allocation churn was on
|
||||
// the order of ~1 GB/day from this single ticker, which on small ARM
|
||||
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
|
||||
//
|
||||
// The event-driven version below allocates only when the kernel actually
|
||||
// publishes a link event for the tracked interface — typically zero
|
||||
// allocations between events.
|
||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
// Buffer the channel to absorb event bursts (e.g. when many veth
|
||||
// pairs are created/destroyed at once by container runtimes).
|
||||
linkChan := make(chan netlink.LinkUpdate, 32)
|
||||
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||
// Return shouldRestart=true so the engine recovers monitoring
|
||||
// via triggerClientRestart instead of silently losing it for
|
||||
// the rest of the process lifetime.
|
||||
return true, fmt.Errorf("subscribe to link updates: %w", err)
|
||||
}
|
||||
|
||||
// Race window: the interface could have been deleted (or recreated)
|
||||
// between the initial getInterfaceIndex() in Start and LinkSubscribe
|
||||
// completing its handshake with the kernel. Re-check explicitly so we
|
||||
// do not block forever waiting for an event that already fired.
|
||||
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
|
||||
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
} else if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||
|
||||
case update, ok := <-linkChan:
|
||||
if !ok {
|
||||
// The vishvananda/netlink subscription goroutine closes
|
||||
// the channel on receive errors. Signal the engine to
|
||||
// restart so monitoring is re-established instead of
|
||||
// silently ending.
|
||||
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
|
||||
return true, fmt.Errorf("link subscription channel closed unexpectedly")
|
||||
}
|
||||
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inspectLinkEvent classifies a single netlink link update against the
|
||||
// tracked WireGuard interface. It returns (true, err) when the engine
|
||||
// should restart monitoring; (false, nil) means the event is unrelated
|
||||
// and the caller should keep waiting.
|
||||
//
|
||||
// The error component, when non-nil, describes the kernel-side reason
|
||||
// (deletion or rename); the recreation case returns (true, nil) since
|
||||
// no error condition is reported.
|
||||
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
|
||||
eventIndex := int(update.Index)
|
||||
eventName := ""
|
||||
if attrs := update.Attrs(); attrs != nil {
|
||||
eventName = attrs.Name
|
||||
}
|
||||
|
||||
switch update.Header.Type {
|
||||
case syscall.RTM_DELLINK:
|
||||
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
|
||||
case syscall.RTM_NEWLINK:
|
||||
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
|
||||
// tracked interface index.
|
||||
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
|
||||
if eventIndex != expectedIndex {
|
||||
return false, nil
|
||||
}
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted", ifaceName)
|
||||
}
|
||||
|
||||
// inspectNewLink reports a restart when an RTM_NEWLINK either:
|
||||
//
|
||||
// 1. Introduces a link with our name at a different index (recreation
|
||||
// after a delete), or
|
||||
//
|
||||
// 2. Reports a link still at our index but with a different name
|
||||
// (in-place rename). The previous polling implementation caught
|
||||
// this implicitly because net.InterfaceByName(ifaceName) would
|
||||
// start failing; the event-driven version has to test it.
|
||||
//
|
||||
// Same name + same index is just a flag/state change on the existing
|
||||
// interface and is ignored.
|
||||
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
|
||||
if eventName == ifaceName && eventIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, eventIndex)
|
||||
return true, nil
|
||||
}
|
||||
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
|
||||
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
|
||||
ifaceName, eventName, expectedIndex)
|
||||
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
56
client/internal/wg_iface_monitor_other.go
Normal file
56
client/internal/wg_iface_monitor_other.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build !linux
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// watchInterface polls net.InterfaceByName at a fixed interval to detect
|
||||
// deletion or recreation of the WireGuard interface.
|
||||
//
|
||||
// This is the fallback used on non-Linux desktop and server platforms
|
||||
// (darwin, windows, freebsd). It is also compiled on android and ios so
|
||||
// the package builds on every supported GOOS, but it is never reached
|
||||
// at runtime there because Start() in wg_iface_monitor.go exits early
|
||||
// on mobile platforms.
|
||||
//
|
||||
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
|
||||
// RTNLGRP_LINK netlink subscription instead, because on Linux
|
||||
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
|
||||
// dumps the entire kernel link table on every call and produces
|
||||
// significant allocation churn (netbirdio/netbird#3678).
|
||||
//
|
||||
// Windows is also reported in #3678 as affected by RSS climb. A future
|
||||
// follow-up could implement an event-driven watcher there using
|
||||
// NotifyIpInterfaceChange from iphlpapi.
|
||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1182,6 +1182,7 @@ type GetConfigResponse struct {
|
||||
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
||||
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
|
||||
DisableFirewall bool `protobuf:"varint,27,opt,name=disable_firewall,json=disableFirewall,proto3" json:"disable_firewall,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1398,6 +1399,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetDisableFirewall() bool {
|
||||
if x != nil {
|
||||
return x.DisableFirewall
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
type PeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -5847,6 +5855,288 @@ func (x *ExposeServiceReady) GetPortAutoAssigned() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type StartCaptureRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
TextOutput bool `protobuf:"varint,1,opt,name=text_output,json=textOutput,proto3" json:"text_output,omitempty"`
|
||||
SnapLen uint32 `protobuf:"varint,2,opt,name=snap_len,json=snapLen,proto3" json:"snap_len,omitempty"`
|
||||
Duration *durationpb.Duration `protobuf:"bytes,3,opt,name=duration,proto3" json:"duration,omitempty"`
|
||||
FilterExpr string `protobuf:"bytes,4,opt,name=filter_expr,json=filterExpr,proto3" json:"filter_expr,omitempty"`
|
||||
Verbose bool `protobuf:"varint,5,opt,name=verbose,proto3" json:"verbose,omitempty"`
|
||||
Ascii bool `protobuf:"varint,6,opt,name=ascii,proto3" json:"ascii,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) Reset() {
|
||||
*x = StartCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{88}
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetTextOutput() bool {
|
||||
if x != nil {
|
||||
return x.TextOutput
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetSnapLen() uint32 {
|
||||
if x != nil {
|
||||
return x.SnapLen
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetDuration() *durationpb.Duration {
|
||||
if x != nil {
|
||||
return x.Duration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetFilterExpr() string {
|
||||
if x != nil {
|
||||
return x.FilterExpr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetVerbose() bool {
|
||||
if x != nil {
|
||||
return x.Verbose
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetAscii() bool {
|
||||
if x != nil {
|
||||
return x.Ascii
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type CapturePacket struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *CapturePacket) Reset() {
|
||||
*x = CapturePacket{}
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *CapturePacket) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*CapturePacket) ProtoMessage() {}
|
||||
|
||||
func (x *CapturePacket) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead.
|
||||
func (*CapturePacket) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{89}
|
||||
}
|
||||
|
||||
func (x *CapturePacket) GetData() []byte {
|
||||
if x != nil {
|
||||
return x.Data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type StartBundleCaptureRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// timeout auto-stops the capture after this duration.
|
||||
// Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum.
|
||||
Timeout *durationpb.Duration `protobuf:"bytes,1,opt,name=timeout,proto3" json:"timeout,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) Reset() {
|
||||
*x = StartBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{90}
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration {
|
||||
if x != nil {
|
||||
return x.Timeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type StartBundleCaptureResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureResponse) Reset() {
|
||||
*x = StartBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StartBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{91}
|
||||
}
|
||||
|
||||
type StopBundleCaptureRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureRequest) Reset() {
|
||||
*x = StopBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{92}
|
||||
}
|
||||
|
||||
type StopBundleCaptureResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureResponse) Reset() {
|
||||
*x = StopBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopBundleCaptureResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{93}
|
||||
}
|
||||
|
||||
type PortInfo_Range struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
||||
@@ -5857,7 +6147,7 @@ type PortInfo_Range struct {
|
||||
|
||||
func (x *PortInfo_Range) Reset() {
|
||||
*x = PortInfo_Range{}
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5869,7 +6159,7 @@ func (x *PortInfo_Range) String() string {
|
||||
func (*PortInfo_Range) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6008,7 +6298,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\fDownResponse\"P\n" +
|
||||
"\x10GetConfigRequest\x12 \n" +
|
||||
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\xdb\b\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\x86\t\n" +
|
||||
"\x11GetConfigResponse\x12$\n" +
|
||||
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
|
||||
"\n" +
|
||||
@@ -6039,7 +6329,8 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
|
||||
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
|
||||
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" +
|
||||
"\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\"\xfe\x05\n" +
|
||||
"\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12)\n" +
|
||||
"\x10disable_firewall\x18\x1b \x01(\bR\x0fdisableFirewall\"\xfe\x05\n" +
|
||||
"\tPeerState\x12\x0e\n" +
|
||||
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
||||
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
|
||||
@@ -6410,7 +6701,23 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\vservice_url\x18\x02 \x01(\tR\n" +
|
||||
"serviceUrl\x12\x16\n" +
|
||||
"\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" +
|
||||
"\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" +
|
||||
"\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned\"\xd9\x01\n" +
|
||||
"\x13StartCaptureRequest\x12\x1f\n" +
|
||||
"\vtext_output\x18\x01 \x01(\bR\n" +
|
||||
"textOutput\x12\x19\n" +
|
||||
"\bsnap_len\x18\x02 \x01(\rR\asnapLen\x125\n" +
|
||||
"\bduration\x18\x03 \x01(\v2\x19.google.protobuf.DurationR\bduration\x12\x1f\n" +
|
||||
"\vfilter_expr\x18\x04 \x01(\tR\n" +
|
||||
"filterExpr\x12\x18\n" +
|
||||
"\averbose\x18\x05 \x01(\bR\averbose\x12\x14\n" +
|
||||
"\x05ascii\x18\x06 \x01(\bR\x05ascii\"#\n" +
|
||||
"\rCapturePacket\x12\x12\n" +
|
||||
"\x04data\x18\x01 \x01(\fR\x04data\"P\n" +
|
||||
"\x19StartBundleCaptureRequest\x123\n" +
|
||||
"\atimeout\x18\x01 \x01(\v2\x19.google.protobuf.DurationR\atimeout\"\x1c\n" +
|
||||
"\x1aStartBundleCaptureResponse\"\x1a\n" +
|
||||
"\x18StopBundleCaptureRequest\"\x1b\n" +
|
||||
"\x19StopBundleCaptureResponse*b\n" +
|
||||
"\bLogLevel\x12\v\n" +
|
||||
"\aUNKNOWN\x10\x00\x12\t\n" +
|
||||
"\x05PANIC\x10\x01\x12\t\n" +
|
||||
@@ -6428,7 +6735,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"EXPOSE_UDP\x10\x03\x12\x0e\n" +
|
||||
"\n" +
|
||||
"EXPOSE_TLS\x10\x042\xac\x15\n" +
|
||||
"EXPOSE_TLS\x10\x042\xaf\x17\n" +
|
||||
"\rDaemonService\x126\n" +
|
||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||
@@ -6449,7 +6756,10 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"CleanState\x12\x19.daemon.CleanStateRequest\x1a\x1a.daemon.CleanStateResponse\"\x00\x12H\n" +
|
||||
"\vDeleteState\x12\x1a.daemon.DeleteStateRequest\x1a\x1b.daemon.DeleteStateResponse\"\x00\x12u\n" +
|
||||
"\x1aSetSyncResponsePersistence\x12).daemon.SetSyncResponsePersistenceRequest\x1a*.daemon.SetSyncResponsePersistenceResponse\"\x00\x12H\n" +
|
||||
"\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" +
|
||||
"\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12F\n" +
|
||||
"\fStartCapture\x12\x1b.daemon.StartCaptureRequest\x1a\x15.daemon.CapturePacket\"\x000\x01\x12]\n" +
|
||||
"\x12StartBundleCapture\x12!.daemon.StartBundleCaptureRequest\x1a\".daemon.StartBundleCaptureResponse\"\x00\x12Z\n" +
|
||||
"\x11StopBundleCapture\x12 .daemon.StopBundleCaptureRequest\x1a!.daemon.StopBundleCaptureResponse\"\x00\x12D\n" +
|
||||
"\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" +
|
||||
"\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12N\n" +
|
||||
"\rSwitchProfile\x12\x1c.daemon.SwitchProfileRequest\x1a\x1d.daemon.SwitchProfileResponse\"\x00\x12B\n" +
|
||||
@@ -6483,7 +6793,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 97)
|
||||
var file_daemon_proto_goTypes = []any{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
|
||||
@@ -6577,125 +6887,139 @@ var file_daemon_proto_goTypes = []any{
|
||||
(*ExposeServiceRequest)(nil), // 89: daemon.ExposeServiceRequest
|
||||
(*ExposeServiceEvent)(nil), // 90: daemon.ExposeServiceEvent
|
||||
(*ExposeServiceReady)(nil), // 91: daemon.ExposeServiceReady
|
||||
nil, // 92: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 93: daemon.PortInfo.Range
|
||||
nil, // 94: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 95: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 96: google.protobuf.Timestamp
|
||||
(*StartCaptureRequest)(nil), // 92: daemon.StartCaptureRequest
|
||||
(*CapturePacket)(nil), // 93: daemon.CapturePacket
|
||||
(*StartBundleCaptureRequest)(nil), // 94: daemon.StartBundleCaptureRequest
|
||||
(*StartBundleCaptureResponse)(nil), // 95: daemon.StartBundleCaptureResponse
|
||||
(*StopBundleCaptureRequest)(nil), // 96: daemon.StopBundleCaptureRequest
|
||||
(*StopBundleCaptureResponse)(nil), // 97: daemon.StopBundleCaptureResponse
|
||||
nil, // 98: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 99: daemon.PortInfo.Range
|
||||
nil, // 100: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 101: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 102: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
95, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
25, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
96, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
96, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
95, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
19, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
18, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
17, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
21, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
22, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
92, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
93, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||
0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
41, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||
50, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||
52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
3, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
96, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
94, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
95, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||
91, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
30, // 34: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
5, // 35: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
7, // 36: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
9, // 37: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
11, // 38: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
13, // 39: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
15, // 40: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
26, // 41: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||
28, // 42: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
28, // 43: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
4, // 44: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
||||
35, // 45: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
37, // 46: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
39, // 47: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
42, // 48: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||
44, // 49: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||
46, // 50: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
48, // 51: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
51, // 52: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
54, // 53: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
56, // 54: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
58, // 55: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
60, // 56: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
||||
62, // 57: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
||||
64, // 58: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
||||
66, // 59: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
||||
69, // 60: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
71, // 61: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
73, // 62: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
75, // 63: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
77, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
79, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
81, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
83, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
85, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
87, // 69: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
89, // 70: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
6, // 71: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
8, // 72: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
10, // 73: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
12, // 74: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
14, // 75: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
16, // 76: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
27, // 77: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
29, // 78: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
29, // 79: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
34, // 80: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
36, // 81: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
38, // 82: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
40, // 83: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
43, // 84: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
45, // 85: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
47, // 86: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
49, // 87: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
53, // 88: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
55, // 89: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
57, // 90: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
59, // 91: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
61, // 92: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
63, // 93: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
65, // 94: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
67, // 95: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
70, // 96: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
72, // 97: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
74, // 98: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
76, // 99: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
78, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
80, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
82, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
84, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
86, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
88, // 105: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
90, // 106: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
71, // [71:107] is the sub-list for method output_type
|
||||
35, // [35:71] is the sub-list for method input_type
|
||||
35, // [35:35] is the sub-list for extension type_name
|
||||
35, // [35:35] is the sub-list for extension extendee
|
||||
0, // [0:35] is the sub-list for field type_name
|
||||
101, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
25, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
102, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
102, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
101, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
19, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
18, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
17, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
21, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
22, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
98, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
99, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||
0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
41, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||
50, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
|
||||
52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
3, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
102, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
100, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
101, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||
91, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
101, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration
|
||||
101, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration
|
||||
30, // 36: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
5, // 37: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
7, // 38: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
9, // 39: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
11, // 40: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
13, // 41: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
15, // 42: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
26, // 43: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
|
||||
28, // 44: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
28, // 45: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
|
||||
4, // 46: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
|
||||
35, // 47: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
37, // 48: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
39, // 49: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
42, // 50: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||
44, // 51: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||
46, // 52: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
48, // 53: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
51, // 54: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
92, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
|
||||
94, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
|
||||
96, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
|
||||
54, // 58: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
56, // 59: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
58, // 60: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
60, // 61: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
|
||||
62, // 62: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
|
||||
64, // 63: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
|
||||
66, // 64: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
|
||||
69, // 65: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
71, // 66: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
73, // 67: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
75, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
77, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
79, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
81, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
83, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
85, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
87, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
89, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
6, // 76: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
8, // 77: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
10, // 78: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
12, // 79: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
14, // 80: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
16, // 81: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
27, // 82: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
29, // 83: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
29, // 84: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
34, // 85: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
36, // 86: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
38, // 87: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
40, // 88: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
43, // 89: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
45, // 90: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
47, // 91: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
49, // 92: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
53, // 93: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
93, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
|
||||
95, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
|
||||
97, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
|
||||
55, // 97: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
57, // 98: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
59, // 99: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
61, // 100: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
63, // 101: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
65, // 102: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
67, // 103: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
70, // 104: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
72, // 105: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
74, // 106: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
76, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
78, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
80, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
82, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
84, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
86, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
88, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
90, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
76, // [76:115] is the sub-list for method output_type
|
||||
37, // [37:76] is the sub-list for method input_type
|
||||
37, // [37:37] is the sub-list for extension type_name
|
||||
37, // [37:37] is the sub-list for extension extendee
|
||||
0, // [0:37] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_daemon_proto_init() }
|
||||
@@ -6725,7 +7049,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||
NumEnums: 4,
|
||||
NumMessages: 91,
|
||||
NumMessages: 97,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -64,6 +64,17 @@ service DaemonService {
|
||||
|
||||
rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}
|
||||
|
||||
// StartCapture begins streaming packet capture on the WireGuard interface.
|
||||
// Requires --enable-capture set at service install/reconfigure time.
|
||||
rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {}
|
||||
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file
|
||||
// for inclusion in the next debug bundle. Auto-stops after the given timeout.
|
||||
rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {}
|
||||
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {}
|
||||
|
||||
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
|
||||
|
||||
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
|
||||
@@ -300,6 +311,8 @@ message GetConfigResponse {
|
||||
bool disableSSHAuth = 25;
|
||||
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_firewall = 27;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -832,3 +845,26 @@ message ExposeServiceReady {
|
||||
string domain = 3;
|
||||
bool port_auto_assigned = 4;
|
||||
}
|
||||
|
||||
message StartCaptureRequest {
|
||||
bool text_output = 1;
|
||||
uint32 snap_len = 2;
|
||||
google.protobuf.Duration duration = 3;
|
||||
string filter_expr = 4;
|
||||
bool verbose = 5;
|
||||
bool ascii = 6;
|
||||
}
|
||||
|
||||
message CapturePacket {
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message StartBundleCaptureRequest {
|
||||
// timeout auto-stops the capture after this duration.
|
||||
// Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum.
|
||||
google.protobuf.Duration timeout = 1;
|
||||
}
|
||||
|
||||
message StartBundleCaptureResponse {}
|
||||
message StopBundleCaptureRequest {}
|
||||
message StopBundleCaptureResponse {}
|
||||
|
||||
@@ -37,6 +37,9 @@ const (
|
||||
DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState"
|
||||
DaemonService_SetSyncResponsePersistence_FullMethodName = "/daemon.DaemonService/SetSyncResponsePersistence"
|
||||
DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket"
|
||||
DaemonService_StartCapture_FullMethodName = "/daemon.DaemonService/StartCapture"
|
||||
DaemonService_StartBundleCapture_FullMethodName = "/daemon.DaemonService/StartBundleCapture"
|
||||
DaemonService_StopBundleCapture_FullMethodName = "/daemon.DaemonService/StopBundleCapture"
|
||||
DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents"
|
||||
DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents"
|
||||
DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile"
|
||||
@@ -96,6 +99,14 @@ type DaemonServiceClient interface {
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error)
|
||||
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
|
||||
// StartCapture begins streaming packet capture on the WireGuard interface.
|
||||
// Requires --enable-capture set at service install/reconfigure time.
|
||||
StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error)
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file
|
||||
// for inclusion in the next debug bundle. Auto-stops after the given timeout.
|
||||
StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error)
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error)
|
||||
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error)
|
||||
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
|
||||
SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error)
|
||||
@@ -313,9 +324,48 @@ func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_StartCapture_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &grpc.GenericClientStream[StartCaptureRequest, CapturePacket]{ClientStream: stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_StartCaptureClient = grpc.ServerStreamingClient[CapturePacket]
|
||||
|
||||
func (c *daemonServiceClient) StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(StartBundleCaptureResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_StartBundleCapture_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(StopBundleCaptureResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_StopBundleCapture_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_SubscribeEvents_FullMethodName, cOpts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], DaemonService_SubscribeEvents_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -494,7 +544,7 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
|
||||
|
||||
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], DaemonService_ExposeService_FullMethodName, cOpts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[2], DaemonService_ExposeService_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -550,6 +600,14 @@ type DaemonServiceServer interface {
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error)
|
||||
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
|
||||
// StartCapture begins streaming packet capture on the WireGuard interface.
|
||||
// Requires --enable-capture set at service install/reconfigure time.
|
||||
StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file
|
||||
// for inclusion in the next debug bundle. Auto-stops after the given timeout.
|
||||
StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error)
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error)
|
||||
SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error
|
||||
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
|
||||
SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error)
|
||||
@@ -641,6 +699,15 @@ func (UnimplementedDaemonServiceServer) SetSyncResponsePersistence(context.Conte
|
||||
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method TracePacket not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error {
|
||||
return status.Error(codes.Unimplemented, "method StartCapture not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method StartBundleCapture not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method StopBundleCapture not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error {
|
||||
return status.Error(codes.Unimplemented, "method SubscribeEvents not implemented")
|
||||
}
|
||||
@@ -1040,6 +1107,53 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_StartCapture_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(StartCaptureRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(DaemonServiceServer).StartCapture(m, &grpc.GenericServerStream[StartCaptureRequest, CapturePacket]{ServerStream: stream})
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_StartCaptureServer = grpc.ServerStreamingServer[CapturePacket]
|
||||
|
||||
func _DaemonService_StartBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StartBundleCaptureRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).StartBundleCapture(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_StartBundleCapture_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).StartBundleCapture(ctx, req.(*StartBundleCaptureRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_StopBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StopBundleCaptureRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).StopBundleCapture(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_StopBundleCapture_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).StopBundleCapture(ctx, req.(*StopBundleCaptureRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(SubscribeRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
@@ -1429,6 +1543,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "TracePacket",
|
||||
Handler: _DaemonService_TracePacket_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StartBundleCapture",
|
||||
Handler: _DaemonService_StartBundleCapture_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StopBundleCapture",
|
||||
Handler: _DaemonService_StopBundleCapture_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetEvents",
|
||||
Handler: _DaemonService_GetEvents_Handler,
|
||||
@@ -1495,6 +1617,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "StartCapture",
|
||||
Handler: _DaemonService_StartCapture_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "SubscribeEvents",
|
||||
Handler: _DaemonService_SubscribeEvents_Handler,
|
||||
|
||||
365
client/server/capture.go
Normal file
365
client/server/capture.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
)
|
||||
|
||||
const maxBundleCaptureDuration = 10 * time.Minute
|
||||
|
||||
// bundleCapture holds the state of an in-progress capture destined for the
|
||||
// debug bundle. The lifecycle is:
|
||||
//
|
||||
// StartBundleCapture → capture running, writing to temp file
|
||||
// StopBundleCapture → capture stopped, temp file available
|
||||
// DebugBundle → temp file included in zip, then cleaned up
|
||||
type bundleCapture struct {
|
||||
mu sync.Mutex
|
||||
sess *capture.Session
|
||||
file *os.File
|
||||
engine *internal.Engine
|
||||
cancel context.CancelFunc
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// stop halts the capture session and closes the pcap writer. Idempotent.
|
||||
func (bc *bundleCapture) stop() {
|
||||
bc.mu.Lock()
|
||||
defer bc.mu.Unlock()
|
||||
|
||||
if bc.stopped {
|
||||
return
|
||||
}
|
||||
bc.stopped = true
|
||||
|
||||
if bc.cancel != nil {
|
||||
bc.cancel()
|
||||
}
|
||||
if bc.sess != nil {
|
||||
bc.sess.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// path returns the temp file path, or "" if no file exists.
|
||||
func (bc *bundleCapture) path() string {
|
||||
if bc.file == nil {
|
||||
return ""
|
||||
}
|
||||
return bc.file.Name()
|
||||
}
|
||||
|
||||
// cleanup removes the temp file.
|
||||
func (bc *bundleCapture) cleanup() {
|
||||
if bc.file == nil {
|
||||
return
|
||||
}
|
||||
name := bc.file.Name()
|
||||
if err := bc.file.Close(); err != nil {
|
||||
log.Debugf("close bundle capture file: %v", err)
|
||||
}
|
||||
if err := os.Remove(name); err != nil && !os.IsNotExist(err) {
|
||||
log.Debugf("remove bundle capture file: %v", err)
|
||||
}
|
||||
bc.file = nil
|
||||
}
|
||||
|
||||
// StartCapture streams a pcap or text packet capture over gRPC.
|
||||
// Gated by the --enable-capture service flag.
|
||||
func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error {
|
||||
if !s.captureEnabled {
|
||||
return status.Error(codes.PermissionDenied,
|
||||
"packet capture is disabled; reinstall or reconfigure the service with --enable-capture")
|
||||
}
|
||||
|
||||
if d := req.GetDuration(); d != nil && d.AsDuration() < 0 {
|
||||
return status.Error(codes.InvalidArgument, "duration must not be negative")
|
||||
}
|
||||
|
||||
matcher, err := parseCaptureFilter(req)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
opts := capture.Options{
|
||||
Matcher: matcher,
|
||||
SnapLen: req.GetSnapLen(),
|
||||
Verbose: req.GetVerbose(),
|
||||
ASCII: req.GetAscii(),
|
||||
}
|
||||
if req.GetTextOutput() {
|
||||
opts.TextOutput = pw
|
||||
} else {
|
||||
opts.Output = pw
|
||||
}
|
||||
|
||||
sess, err := capture.NewSession(opts)
|
||||
if err != nil {
|
||||
pw.Close()
|
||||
return status.Errorf(codes.Internal, "create capture session: %v", err)
|
||||
}
|
||||
|
||||
engine, err := s.claimCapture(sess)
|
||||
if err != nil {
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := engine.SetCapture(sess); err != nil {
|
||||
s.releaseCapture(sess)
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
return status.Errorf(codes.Internal, "set capture: %v", err)
|
||||
}
|
||||
|
||||
// Send an empty initial message to signal that the capture was accepted.
|
||||
// The client waits for this before printing the banner, so it must arrive
|
||||
// before any packet data.
|
||||
if err := stream.Send(&proto.CapturePacket{}); err != nil {
|
||||
s.clearCaptureIfOwner(sess, engine)
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
return status.Errorf(codes.Internal, "send initial message: %v", err)
|
||||
}
|
||||
|
||||
ctx := stream.Context()
|
||||
if d := req.GetDuration(); d != nil {
|
||||
if dur := d.AsDuration(); dur > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, dur)
|
||||
defer cancel()
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
s.clearCaptureIfOwner(sess, engine)
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
}()
|
||||
defer pr.Close()
|
||||
|
||||
log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr())
|
||||
defer func() {
|
||||
stats := sess.Stats()
|
||||
log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped",
|
||||
stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}()
|
||||
|
||||
return streamToGRPC(pr, stream)
|
||||
}
|
||||
|
||||
func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, readErr := r.Read(buf)
|
||||
if n > 0 {
|
||||
if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil {
|
||||
log.Debugf("capture stream send: %v", err)
|
||||
return nil //nolint:nilerr // client disconnected
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
return nil //nolint:nilerr // pipe closed, capture stopped normally
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartBundleCapture begins capturing packets to a server-side temp file for
|
||||
// inclusion in the next debug bundle. Not gated by --enable-capture since the
|
||||
// output stays on the server (same trust level as CPU profiling).
|
||||
//
|
||||
// A timeout auto-stops the capture as a safety net if StopBundleCapture is
|
||||
// never called (e.g. CLI crash).
|
||||
func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.stopBundleCaptureLocked()
|
||||
s.cleanupBundleCapture()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
// Not fatal: kernel mode or not connected. Log and return success
|
||||
// so the debug bundle still generates without capture data.
|
||||
log.Warnf("packet capture unavailable, skipping: %v", err)
|
||||
return &proto.StartBundleCaptureResponse{}, nil
|
||||
}
|
||||
|
||||
timeout := req.GetTimeout().AsDuration()
|
||||
if timeout <= 0 || timeout > maxBundleCaptureDuration {
|
||||
timeout = maxBundleCaptureDuration
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "netbird.capture.*.pcap")
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "create temp file: %v", err)
|
||||
}
|
||||
|
||||
sess, err := capture.NewSession(capture.Options{Output: f})
|
||||
if err != nil {
|
||||
f.Close()
|
||||
os.Remove(f.Name())
|
||||
return nil, status.Errorf(codes.Internal, "create capture session: %v", err)
|
||||
}
|
||||
|
||||
if err := engine.SetCapture(sess); err != nil {
|
||||
sess.Stop()
|
||||
f.Close()
|
||||
os.Remove(f.Name())
|
||||
log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err)
|
||||
return &proto.StartBundleCaptureResponse{}, nil
|
||||
}
|
||||
s.activeCapture = sess
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
bc := &bundleCapture{
|
||||
sess: sess,
|
||||
file: f,
|
||||
engine: engine,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.bundleCapture = bc
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
s.mutex.Lock()
|
||||
if s.bundleCapture == bc {
|
||||
s.stopBundleCaptureLocked()
|
||||
} else {
|
||||
bc.stop()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
log.Infof("bundle capture auto-stopped after timeout")
|
||||
}()
|
||||
log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name())
|
||||
|
||||
return &proto.StartBundleCaptureResponse{}, nil
|
||||
}
|
||||
|
||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
||||
func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.stopBundleCaptureLocked()
|
||||
return &proto.StopBundleCaptureResponse{}, nil
|
||||
}
|
||||
|
||||
// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex.
|
||||
func (s *Server) stopBundleCaptureLocked() {
|
||||
if s.bundleCapture == nil {
|
||||
return
|
||||
}
|
||||
bc := s.bundleCapture
|
||||
if bc.engine != nil && s.activeCapture == bc.sess {
|
||||
if err := bc.engine.SetCapture(nil); err != nil {
|
||||
log.Debugf("clear bundle capture: %v", err)
|
||||
}
|
||||
s.activeCapture = nil
|
||||
}
|
||||
bc.stop()
|
||||
|
||||
stats := bc.sess.Stats()
|
||||
log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped",
|
||||
stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}
|
||||
|
||||
// bundleCapturePath returns the temp file path if a capture has been taken,
|
||||
// stops any running capture, and returns "". Called from DebugBundle.
|
||||
// Must hold s.mutex.
|
||||
func (s *Server) bundleCapturePath() string {
|
||||
if s.bundleCapture == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
s.bundleCapture.stop()
|
||||
return s.bundleCapture.path()
|
||||
}
|
||||
|
||||
// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex.
|
||||
func (s *Server) cleanupBundleCapture() {
|
||||
if s.bundleCapture == nil {
|
||||
return
|
||||
}
|
||||
s.bundleCapture.cleanup()
|
||||
s.bundleCapture = nil
|
||||
}
|
||||
|
||||
// claimCapture reserves the engine's capture slot for sess. Returns
|
||||
// FailedPrecondition if another capture is already active.
|
||||
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.activeCapture = sess
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// releaseCapture clears the active-capture owner if it still matches sess.
|
||||
func (s *Server) releaseCapture(sess *capture.Session) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.activeCapture == sess {
|
||||
s.activeCapture = nil
|
||||
}
|
||||
}
|
||||
|
||||
// clearCaptureIfOwner clears engine's capture slot only if sess still owns it.
|
||||
func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.activeCapture != sess {
|
||||
return
|
||||
}
|
||||
if err := engine.SetCapture(nil); err != nil {
|
||||
log.Debugf("clear capture: %v", err)
|
||||
}
|
||||
s.activeCapture = nil
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "client not connected")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// parseCaptureFilter returns a Matcher from the request.
|
||||
// Returns nil (match all) when no filter expression is set.
|
||||
func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) {
|
||||
expr := req.GetFilterExpr()
|
||||
if expr == "" {
|
||||
return nil, nil //nolint:nilnil // nil Matcher means "match all"
|
||||
}
|
||||
return capture.ParseFilter(expr)
|
||||
}
|
||||
@@ -43,7 +43,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
}()
|
||||
}
|
||||
|
||||
// Prepare refresh callback for health probes
|
||||
capturePath := s.bundleCapturePath()
|
||||
defer s.cleanupBundleCapture()
|
||||
|
||||
var refreshStatus func()
|
||||
if s.connectClient != nil {
|
||||
engine := s.connectClient.Engine()
|
||||
@@ -62,6 +64,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: s.logFile,
|
||||
CPUProfile: cpuProfileData,
|
||||
CapturePath: capturePath,
|
||||
RefreshStatus: refreshStatus,
|
||||
ClientMetrics: clientMetrics,
|
||||
},
|
||||
|
||||
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -89,7 +90,11 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
networksDisabled bool
|
||||
captureEnabled bool
|
||||
bundleCapture *bundleCapture
|
||||
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
|
||||
activeCapture *capture.Session
|
||||
networksDisabled bool
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
@@ -106,7 +111,7 @@ type oauthAuthFlow struct {
|
||||
}
|
||||
|
||||
// New server instance constructor.
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server {
|
||||
s := &Server{
|
||||
rootCtx: ctx,
|
||||
logFile: logFile,
|
||||
@@ -115,6 +120,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
captureEnabled: captureEnabled,
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
@@ -1478,6 +1484,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
disableClientRoutes := cfg.DisableClientRoutes
|
||||
disableServerRoutes := cfg.DisableServerRoutes
|
||||
blockLANAccess := cfg.BlockLANAccess
|
||||
disableFirewall := cfg.DisableFirewall
|
||||
|
||||
enableSSHRoot := false
|
||||
if cfg.EnableSSHRoot != nil {
|
||||
@@ -1534,6 +1541,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
DisableFirewall: disableFirewall,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false)
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
@@ -165,7 +165,7 @@ func TestServer_Up(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
s := New(ctx, "console", "", false, false, false, false)
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -235,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
s := New(ctx, "console", "", false, false, false, false)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
s := New(ctx, "console", "", false, false, false, false)
|
||||
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
|
||||
@@ -224,15 +224,20 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
sshConfigPathTmp := sshConfigPath + ".tmp"
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
|
||||
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"fyne.io/fyne/v2/widget"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -38,6 +39,7 @@ type debugCollectionParams struct {
|
||||
upload bool
|
||||
uploadURL string
|
||||
enablePersistence bool
|
||||
capture bool
|
||||
}
|
||||
|
||||
// UI components for progress tracking
|
||||
@@ -51,25 +53,58 @@ type progressUI struct {
|
||||
func (s *serviceClient) showDebugUI() {
|
||||
w := s.app.NewWindow("NetBird Debug")
|
||||
w.SetOnClosed(s.cancel)
|
||||
|
||||
w.Resize(fyne.NewSize(600, 500))
|
||||
w.SetFixedSize(true)
|
||||
|
||||
anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)
|
||||
systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
|
||||
systemInfoCheck.SetChecked(true)
|
||||
captureCheck := widget.NewCheck("Include packet capture", nil)
|
||||
uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
|
||||
uploadCheck.SetChecked(true)
|
||||
|
||||
uploadURLLabel := widget.NewLabel("Debug upload URL:")
|
||||
uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck)
|
||||
|
||||
debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection()
|
||||
|
||||
statusLabel := widget.NewLabel("")
|
||||
statusLabel.Hide()
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.Hide()
|
||||
createButton := widget.NewButton("Create Debug Bundle", nil)
|
||||
|
||||
uiControls := []fyne.Disableable{
|
||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
||||
uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton,
|
||||
}
|
||||
|
||||
createButton.OnTapped = s.getCreateHandler(
|
||||
statusLabel, progressBar, uploadCheck, uploadURL,
|
||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
||||
runForDurationCheck, durationInput, uiControls, w,
|
||||
)
|
||||
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
|
||||
widget.NewLabel(""),
|
||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
||||
uploadCheck, uploadURLContainer,
|
||||
widget.NewLabel(""),
|
||||
debugModeContainer, noteLabel,
|
||||
widget.NewLabel(""),
|
||||
statusLabel, progressBar, createButton,
|
||||
)
|
||||
|
||||
w.SetContent(container.NewPadded(content))
|
||||
w.Show()
|
||||
}
|
||||
|
||||
func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) {
|
||||
uploadURL := widget.NewEntry()
|
||||
uploadURL.SetText(uptypes.DefaultBundleURL)
|
||||
uploadURL.SetPlaceHolder("Enter upload URL")
|
||||
|
||||
uploadURLContainer := container.NewVBox(
|
||||
uploadURLLabel,
|
||||
uploadURL,
|
||||
)
|
||||
uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL)
|
||||
|
||||
uploadCheck.OnChanged = func(checked bool) {
|
||||
if checked {
|
||||
@@ -78,13 +113,14 @@ func (s *serviceClient) showDebugUI() {
|
||||
uploadURLContainer.Hide()
|
||||
}
|
||||
}
|
||||
return uploadURLContainer, uploadURL
|
||||
}
|
||||
|
||||
debugModeContainer := container.NewHBox()
|
||||
func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) {
|
||||
runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
|
||||
runForDurationCheck.SetChecked(true)
|
||||
|
||||
forLabel := widget.NewLabel("for")
|
||||
|
||||
durationInput := widget.NewEntry()
|
||||
durationInput.SetText("1")
|
||||
minutesLabel := widget.NewLabel("minute")
|
||||
@@ -108,63 +144,8 @@ func (s *serviceClient) showDebugUI() {
|
||||
}
|
||||
}
|
||||
|
||||
debugModeContainer.Add(runForDurationCheck)
|
||||
debugModeContainer.Add(forLabel)
|
||||
debugModeContainer.Add(durationInput)
|
||||
debugModeContainer.Add(minutesLabel)
|
||||
|
||||
statusLabel := widget.NewLabel("")
|
||||
statusLabel.Hide()
|
||||
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.Hide()
|
||||
|
||||
createButton := widget.NewButton("Create Debug Bundle", nil)
|
||||
|
||||
// UI controls that should be disabled during debug collection
|
||||
uiControls := []fyne.Disableable{
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
uploadCheck,
|
||||
uploadURL,
|
||||
runForDurationCheck,
|
||||
durationInput,
|
||||
createButton,
|
||||
}
|
||||
|
||||
createButton.OnTapped = s.getCreateHandler(
|
||||
statusLabel,
|
||||
progressBar,
|
||||
uploadCheck,
|
||||
uploadURL,
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
runForDurationCheck,
|
||||
durationInput,
|
||||
uiControls,
|
||||
w,
|
||||
)
|
||||
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
|
||||
widget.NewLabel(""),
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
uploadCheck,
|
||||
uploadURLContainer,
|
||||
widget.NewLabel(""),
|
||||
debugModeContainer,
|
||||
noteLabel,
|
||||
widget.NewLabel(""),
|
||||
statusLabel,
|
||||
progressBar,
|
||||
createButton,
|
||||
)
|
||||
|
||||
paddedContent := container.NewPadded(content)
|
||||
w.SetContent(paddedContent)
|
||||
|
||||
w.Show()
|
||||
modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel)
|
||||
return modeContainer, runForDurationCheck, durationInput, noteLabel
|
||||
}
|
||||
|
||||
func validateMinute(s string, minutesLabel *widget.Label) error {
|
||||
@@ -200,6 +181,7 @@ func (s *serviceClient) getCreateHandler(
|
||||
uploadURL *widget.Entry,
|
||||
anonymizeCheck *widget.Check,
|
||||
systemInfoCheck *widget.Check,
|
||||
captureCheck *widget.Check,
|
||||
runForDurationCheck *widget.Check,
|
||||
duration *widget.Entry,
|
||||
uiControls []fyne.Disableable,
|
||||
@@ -222,6 +204,7 @@ func (s *serviceClient) getCreateHandler(
|
||||
params := &debugCollectionParams{
|
||||
anonymize: anonymizeCheck.Checked,
|
||||
systemInfo: systemInfoCheck.Checked,
|
||||
capture: captureCheck.Checked,
|
||||
upload: uploadCheck.Checked,
|
||||
uploadURL: url,
|
||||
enablePersistence: true,
|
||||
@@ -253,10 +236,7 @@ func (s *serviceClient) getCreateHandler(
|
||||
|
||||
statusLabel.SetText("Creating debug bundle...")
|
||||
go s.handleDebugCreation(
|
||||
anonymizeCheck.Checked,
|
||||
systemInfoCheck.Checked,
|
||||
uploadCheck.Checked,
|
||||
url,
|
||||
params,
|
||||
statusLabel,
|
||||
uiControls,
|
||||
w,
|
||||
@@ -371,7 +351,7 @@ func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time
|
||||
func (s *serviceClient) configureServiceForDebug(
|
||||
conn proto.DaemonServiceClient,
|
||||
state *debugInitialState,
|
||||
enablePersistence bool,
|
||||
params *debugCollectionParams,
|
||||
) {
|
||||
if state.wasDown {
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
@@ -397,7 +377,7 @@ func (s *serviceClient) configureServiceForDebug(
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if enablePersistence {
|
||||
if params.enablePersistence {
|
||||
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
@@ -417,6 +397,26 @@ func (s *serviceClient) configureServiceForDebug(
|
||||
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
|
||||
log.Warnf("failed to start CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
s.startBundleCaptureIfEnabled(conn, params)
|
||||
}
|
||||
|
||||
func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) {
|
||||
if !params.capture {
|
||||
return
|
||||
}
|
||||
|
||||
const maxCapture = 10 * time.Minute
|
||||
timeout := params.duration + 30*time.Second
|
||||
if timeout > maxCapture {
|
||||
timeout = maxCapture
|
||||
log.Warnf("packet capture clamped to %s (server maximum)", maxCapture)
|
||||
}
|
||||
if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
|
||||
Timeout: durationpb.New(timeout),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to start bundle capture: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) collectDebugData(
|
||||
@@ -430,7 +430,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var wg sync.WaitGroup
|
||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||
|
||||
s.configureServiceForDebug(conn, state, params.enablePersistence)
|
||||
s.configureServiceForDebug(conn, state, params)
|
||||
|
||||
wg.Wait()
|
||||
progress.progressBar.Hide()
|
||||
@@ -440,6 +440,14 @@ func (s *serviceClient) collectDebugData(
|
||||
log.Warnf("failed to stop CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
if params.capture {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
log.Warnf("failed to stop bundle capture: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -520,18 +528,37 @@ func handleError(progress *progressUI, errMsg string) {
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleDebugCreation(
|
||||
anonymize bool,
|
||||
systemInfo bool,
|
||||
upload bool,
|
||||
uploadURL string,
|
||||
params *debugCollectionParams,
|
||||
statusLabel *widget.Label,
|
||||
uiControls []fyne.Disableable,
|
||||
w fyne.Window,
|
||||
) {
|
||||
log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...",
|
||||
anonymize, systemInfo, upload)
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get client for debug: %v", err)
|
||||
statusLabel.SetText(fmt.Sprintf("Error: %v", err))
|
||||
enableUIControls(uiControls)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL)
|
||||
if params.capture {
|
||||
if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
|
||||
Timeout: durationpb.New(30 * time.Second),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to start bundle capture: %v", err)
|
||||
} else {
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
||||
log.Warnf("failed to stop bundle capture: %v", err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create debug bundle: %v", err)
|
||||
statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))
|
||||
@@ -543,7 +570,7 @@ func (s *serviceClient) handleDebugCreation(
|
||||
uploadFailureReason := resp.GetUploadFailureReason()
|
||||
uploadedKey := resp.GetUploadedKey()
|
||||
|
||||
if upload {
|
||||
if params.upload {
|
||||
if uploadFailureReason != "" {
|
||||
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
||||
} else {
|
||||
|
||||
@@ -5,6 +5,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
@@ -459,6 +461,95 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// createStartCaptureMethod creates the programmable packet capture method.
|
||||
// Returns a JS interface with onpacket callback and stop() method.
|
||||
//
|
||||
// Usage from JavaScript:
|
||||
//
|
||||
// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true })
|
||||
// cap.onpacket = (line) => console.log(line)
|
||||
// const stats = cap.stop()
|
||||
func createStartCaptureMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
var opts js.Value
|
||||
if len(args) > 0 {
|
||||
opts = args[0]
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
iface, err := wasmcapture.Start(client, opts)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
|
||||
return
|
||||
}
|
||||
resolve.Invoke(iface)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// captureMethods returns capture() and stopCapture() that share state for
|
||||
// the console-log shortcut. capture() logs packets to the browser console
|
||||
// and stopCapture() ends it, like Ctrl+C on the CLI.
|
||||
//
|
||||
// Usage from browser devtools console:
|
||||
//
|
||||
// await client.capture() // capture all packets
|
||||
// await client.capture("tcp") // capture with filter
|
||||
// await client.capture({filter: "host 10.0.0.1", verbose: true})
|
||||
// client.stopCapture() // stop and print stats
|
||||
func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
|
||||
var mu sync.Mutex
|
||||
var active *wasmcapture.Handle
|
||||
|
||||
startFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
var opts js.Value
|
||||
if len(args) > 0 {
|
||||
opts = args[0]
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if active != nil {
|
||||
active.Stop()
|
||||
active = nil
|
||||
}
|
||||
|
||||
h, err := wasmcapture.StartConsole(client, opts)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
|
||||
return
|
||||
}
|
||||
active = h
|
||||
|
||||
console := js.Global().Get("console")
|
||||
console.Call("log", "[capture] started, call client.stopCapture() to stop")
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
|
||||
stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if active == nil {
|
||||
js.Global().Get("console").Call("log", "[capture] no active capture")
|
||||
return js.Undefined()
|
||||
}
|
||||
|
||||
stats := active.Stop()
|
||||
active = nil
|
||||
|
||||
console := js.Global().Get("console")
|
||||
console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped",
|
||||
stats.Packets, stats.Bytes, stats.Dropped))
|
||||
return js.Undefined()
|
||||
})
|
||||
|
||||
return startFn, stopFn
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
@@ -521,6 +612,11 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||
obj["startCapture"] = createStartCaptureMethod(client)
|
||||
|
||||
capStart, capStop := captureMethods(client)
|
||||
obj["capture"] = capStart
|
||||
obj["stopCapture"] = capStop
|
||||
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
176
client/wasm/internal/capture/capture.go
Normal file
176
client/wasm/internal/capture/capture.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build js
|
||||
|
||||
// Package capture bridges the util/capture package to JavaScript via syscall/js.
|
||||
package capture
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
// Handle holds a running capture session so it can be stopped later.
|
||||
type Handle struct {
|
||||
cs *netbird.CaptureSession
|
||||
stopFn js.Func
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// Stop ends the capture and returns stats.
|
||||
func (h *Handle) Stop() netbird.CaptureStats {
|
||||
if h.stopped {
|
||||
return h.cs.Stats()
|
||||
}
|
||||
h.stopped = true
|
||||
h.stopFn.Release()
|
||||
|
||||
h.cs.Stop()
|
||||
return h.cs.Stats()
|
||||
}
|
||||
|
||||
func statsToJS(s netbird.CaptureStats) js.Value {
|
||||
obj := js.Global().Get("Object").Call("create", js.Null())
|
||||
obj.Set("packets", js.ValueOf(s.Packets))
|
||||
obj.Set("bytes", js.ValueOf(s.Bytes))
|
||||
obj.Set("dropped", js.ValueOf(s.Dropped))
|
||||
return obj
|
||||
}
|
||||
|
||||
// parseOpts extracts filter/verbose/ascii from a JS options value.
|
||||
func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) {
|
||||
if jsOpts.IsNull() || jsOpts.IsUndefined() {
|
||||
return
|
||||
}
|
||||
if jsOpts.Type() == js.TypeString {
|
||||
filter = jsOpts.String()
|
||||
return
|
||||
}
|
||||
if jsOpts.Type() != js.TypeObject {
|
||||
return
|
||||
}
|
||||
if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() {
|
||||
filter = f.String()
|
||||
}
|
||||
if v := jsOpts.Get("verbose"); !v.IsUndefined() {
|
||||
verbose = v.Truthy()
|
||||
}
|
||||
if a := jsOpts.Get("ascii"); !a.IsUndefined() {
|
||||
ascii = a.Truthy()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Start creates a capture session and returns a JS interface for streaming text
|
||||
// output. The returned object exposes:
|
||||
//
|
||||
// onpacket(callback) - set callback(string) for each text line
|
||||
// stop() - stop capture and return stats { packets, bytes, dropped }
|
||||
//
|
||||
// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string.
|
||||
func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) {
|
||||
filter, verbose, ascii := parseOpts(jsOpts)
|
||||
|
||||
cb := &jsCallbackWriter{}
|
||||
|
||||
cs, err := client.StartCapture(netbird.CaptureOptions{
|
||||
TextOutput: cb,
|
||||
Filter: filter,
|
||||
Verbose: verbose,
|
||||
ASCII: ascii,
|
||||
})
|
||||
if err != nil {
|
||||
return js.Undefined(), err
|
||||
}
|
||||
|
||||
handle := &Handle{cs: cs}
|
||||
|
||||
iface := js.Global().Get("Object").Call("create", js.Null())
|
||||
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
return statsToJS(handle.Stop())
|
||||
})
|
||||
iface.Set("stop", handle.stopFn)
|
||||
iface.Set("onpacket", js.Undefined())
|
||||
cb.setInterface(iface)
|
||||
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
// StartConsole starts a capture that logs every packet line to console.log.
|
||||
// Returns a Handle so the caller can stop it later.
|
||||
func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) {
|
||||
filter, verbose, ascii := parseOpts(jsOpts)
|
||||
|
||||
cb := &jsCallbackWriter{}
|
||||
|
||||
cs, err := client.StartCapture(netbird.CaptureOptions{
|
||||
TextOutput: cb,
|
||||
Filter: filter,
|
||||
Verbose: verbose,
|
||||
ASCII: ascii,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handle := &Handle{cs: cs}
|
||||
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
return statsToJS(handle.Stop())
|
||||
})
|
||||
|
||||
iface := js.Global().Get("Object").Call("create", js.Null())
|
||||
console := js.Global().Get("console")
|
||||
iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]")))
|
||||
cb.setInterface(iface)
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
// jsCallbackWriter is an io.Writer that buffers text until a newline, then
|
||||
// invokes the JS onpacket callback with each complete line.
|
||||
type jsCallbackWriter struct {
|
||||
mu sync.Mutex
|
||||
iface js.Value
|
||||
buf strings.Builder
|
||||
}
|
||||
|
||||
func (w *jsCallbackWriter) setInterface(iface js.Value) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.iface = iface
|
||||
}
|
||||
|
||||
func (w *jsCallbackWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
w.buf.Write(p)
|
||||
|
||||
var lines []string
|
||||
for {
|
||||
str := w.buf.String()
|
||||
idx := strings.IndexByte(str, '\n')
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
lines = append(lines, str[:idx])
|
||||
w.buf.Reset()
|
||||
if idx+1 < len(str) {
|
||||
w.buf.WriteString(str[idx+1:])
|
||||
}
|
||||
}
|
||||
|
||||
iface := w.iface
|
||||
w.mu.Unlock()
|
||||
|
||||
if iface.IsUndefined() {
|
||||
return len(p), nil
|
||||
}
|
||||
cb := iface.Get("onpacket")
|
||||
if cb.IsUndefined() || cb.IsNull() {
|
||||
return len(p), nil
|
||||
}
|
||||
for _, line := range lines {
|
||||
cb.Invoke(js.ValueOf(line))
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -13,11 +13,9 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
@@ -301,12 +299,11 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// isContextDone reports whether the local context has been canceled or has
|
||||
// exceeded its deadline. It deliberately does not inspect gRPC status codes:
|
||||
// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not
|
||||
// short-circuit our retry loop, since retrying is the correct response when
|
||||
// the local context is still alive.
|
||||
func isContextDone(err error) bool {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
|
||||
}
|
||||
return false
|
||||
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -71,6 +71,7 @@ require (
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||
@@ -308,6 +309,7 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
rsc.io/qr v0.2.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -415,6 +415,8 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||
@@ -915,3 +917,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||
|
||||
@@ -257,7 +257,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -331,9 +334,13 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
@@ -348,14 +355,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ const (
|
||||
)
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
|
||||
@@ -44,17 +44,17 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// CountStreams mocks base method.
|
||||
@@ -238,15 +238,15 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId a
|
||||
}
|
||||
|
||||
// UpdateAccountPeers mocks base method.
|
||||
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (a *MockAccountManager) GetDeletePeerCalls() int {
|
||||
return a.deletePeerCalls
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
@@ -248,7 +248,7 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
mockAM.BufferUpdateAccountPeers(ctx, accountID)
|
||||
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
|
||||
return nil
|
||||
}).
|
||||
Times(1)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -178,7 +179,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
}
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -231,7 +232,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@@ -515,7 +516,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -819,7 +820,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -860,7 +861,7 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -916,7 +917,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1098,7 +1099,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
|
||||
|
||||
serviceURL := "https://" + svc.Domain
|
||||
if service.IsL4Protocol(svc.Mode) {
|
||||
@@ -1210,7 +1211,7 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1261,7 +1262,7 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -447,7 +447,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
||||
storedActivity = activityID.(activity.Activity)
|
||||
},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -549,7 +549,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
||||
storedActivity = activityID.(activity.Activity)
|
||||
},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -593,7 +593,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) {
|
||||
storedMeta = meta
|
||||
},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -704,7 +704,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
@@ -1173,7 +1173,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
mockAcct.EXPECT().
|
||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||
mockAcct.EXPECT().
|
||||
UpdateAccountPeers(ctx, accountID)
|
||||
UpdateAccountPeers(ctx, accountID, gomock.Any())
|
||||
|
||||
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -144,7 +145,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
@@ -206,7 +207,7 @@ func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -95,7 +96,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationCreate})
|
||||
|
||||
return record, nil
|
||||
}
|
||||
@@ -154,7 +155,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return record, nil
|
||||
}
|
||||
@@ -201,7 +202,7 @@ func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneI
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -400,7 +400,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
go am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceAccountSettings, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return newSettings, nil
|
||||
@@ -1581,7 +1581,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -124,8 +124,8 @@ type Manager interface {
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||
GetStore() store.Store
|
||||
|
||||
@@ -111,15 +111,15 @@ func (mr *MockManagerMockRecorder) ApproveUser(ctx, accountID, initiatorUserID,
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers mocks base method.
|
||||
func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
|
||||
m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
|
||||
func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// BuildUserInfosForAccount mocks base method.
|
||||
@@ -1597,15 +1597,15 @@ func (mr *MockManagerMockRecorder) UpdateAccountOnboarding(ctx, accountID, userI
|
||||
}
|
||||
|
||||
// UpdateAccountPeers mocks base method.
|
||||
func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
|
||||
m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
|
||||
func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks base method.
|
||||
|
||||
8
management/server/account/pat.go
Normal file
8
management/server/account/pat.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package account
|
||||
|
||||
const (
|
||||
// PATMinExpireDays is the minimum allowed Personal Access Token expiration in days.
|
||||
PATMinExpireDays = 1
|
||||
// PATMaxExpireDays is the maximum allowed Personal Access Token expiration in days.
|
||||
PATMaxExpireDays = 365
|
||||
)
|
||||
@@ -1761,7 +1761,7 @@ func hasNilField(x interface{}) error {
|
||||
if f := rv.Field(i); f.IsValid() {
|
||||
k := f.Kind()
|
||||
switch k {
|
||||
case reflect.Ptr:
|
||||
case reflect.Pointer:
|
||||
if f.IsNil() {
|
||||
return fmt.Errorf("field %s is nil", f.String())
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -117,7 +117,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -185,7 +185,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -253,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -321,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -531,7 +531,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -559,7 +559,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -597,7 +597,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -62,9 +62,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
)
|
||||
const apiPrefix = "/api"
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
@@ -141,7 +139,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddEndpoints(instanceManager, accountManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -15,13 +16,15 @@ import (
|
||||
// handler handles the instance setup HTTP endpoints
|
||||
type handler struct {
|
||||
instanceManager nbinstance.Manager
|
||||
setupManager *nbinstance.SetupService
|
||||
}
|
||||
|
||||
// AddEndpoints registers the instance setup endpoints.
|
||||
// These endpoints bypass authentication for initial setup.
|
||||
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
setupManager: nbinstance.NewSetupService(instanceManager, accountManager),
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
|
||||
@@ -55,24 +58,36 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
// setup creates the initial admin user for the instance.
|
||||
// This endpoint is unauthenticated but only works when setup is required.
|
||||
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req api.SetupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
|
||||
result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{
|
||||
CreatePAT: req.CreatePat != nil && *req.CreatePat,
|
||||
PATExpireInDays: req.PatExpireIn,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
|
||||
log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
|
||||
UserId: userData.ID,
|
||||
Email: userData.Email,
|
||||
})
|
||||
resp := api.SetupResponse{
|
||||
UserId: result.User.ID,
|
||||
Email: result.User.Email,
|
||||
}
|
||||
|
||||
if result.PATPlainToken != "" {
|
||||
resp.PersonalAccessToken = &result.PATPlainToken
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
|
||||
@@ -10,12 +10,18 @@ import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbstore "github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -25,6 +31,7 @@ type mockInstanceManager struct {
|
||||
isSetupRequired bool
|
||||
isSetupRequiredFn func(ctx context.Context) (bool, error)
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
rollbackSetupFn func(ctx context.Context, userID string) error
|
||||
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
|
||||
}
|
||||
|
||||
@@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error {
|
||||
if m.rollbackSetupFn != nil {
|
||||
return m.rollbackSetupFn(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
if m.getVersionInfoFn != nil {
|
||||
return m.getVersionInfoFn(ctx)
|
||||
@@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V
|
||||
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
|
||||
|
||||
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
|
||||
return setupTestRouterWithPAT(manager, nil)
|
||||
}
|
||||
|
||||
func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
AddEndpoints(manager, router)
|
||||
AddEndpoints(manager, accountManager, router)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -161,6 +179,7 @@ func TestSetup_Success(t *testing.T) {
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
|
||||
|
||||
var response api.SetupResponse
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
@@ -293,6 +312,239 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "false")
|
||||
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
// NB_SETUP_PAT_ENABLED=false: request fields must be silently ignored
|
||||
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
assert.Nil(t, response.PersonalAccessToken)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
assert.Nil(t, response.PersonalAccessToken)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
createCalled := false
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
createCalled = true
|
||||
return &idp.UserData{ID: "u1", Email: email, Name: name}, nil
|
||||
},
|
||||
}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "u1", userAuth.UserId)
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "u1", initiator)
|
||||
assert.Equal(t, "u1", target)
|
||||
assert.Equal(t, "setup-token", name)
|
||||
assert.Equal(t, 1, expiresIn)
|
||||
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
|
||||
},
|
||||
}
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
|
||||
assert.True(t, createCalled)
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
require.NotNil(t, response.PersonalAccessToken)
|
||||
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_ExpireOutOfRange(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_Success(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
}
|
||||
|
||||
gotAccountArgs := struct {
|
||||
userID string
|
||||
email string
|
||||
}{}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
gotAccountArgs.userID = userAuth.UserId
|
||||
gotAccountArgs.email = userAuth.Email
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "owner-id", initiator)
|
||||
assert.Equal(t, "owner-id", target)
|
||||
assert.Equal(t, "setup-token", name)
|
||||
assert.Equal(t, 30, expiresIn)
|
||||
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
assert.Equal(t, "owner-id", response.UserId)
|
||||
require.NotNil(t, response.PersonalAccessToken)
|
||||
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
|
||||
assert.Equal(t, "owner-id", gotAccountArgs.userID)
|
||||
assert.Equal(t, "admin@example.com", gotAccountArgs.email)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("", status.NewAccountNotFoundError("owner-id"))
|
||||
|
||||
rolledBackFor := ""
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "", errors.New("db down")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
}
|
||||
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id")
|
||||
}
|
||||
|
||||
func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
|
||||
|
||||
rolledBackFor := ""
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
return nil, status.Errorf(status.Internal, "token store unavailable")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
}
|
||||
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called when CreatePAT fails")
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
router := mux.NewRouter()
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -60,6 +61,13 @@ type Manager interface {
|
||||
// This should only be called when IsSetupRequired returns true.
|
||||
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
|
||||
// RollbackSetup reverses a successful CreateOwnerUser by deleting the user
|
||||
// from the embedded IDP and reloading setupRequired from persistent state, so
|
||||
// /api/setup can be retried only when no accounts or local users remain. Used
|
||||
// when post-user steps (account or PAT creation) fail and the caller wants a
|
||||
// clean slate.
|
||||
RollbackSetup(ctx context.Context, userID string) error
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
|
||||
}
|
||||
@@ -70,6 +78,7 @@ type instanceStore interface {
|
||||
|
||||
type embeddedIdP interface {
|
||||
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
|
||||
}
|
||||
|
||||
@@ -187,6 +196,51 @@ func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, n
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the
|
||||
// embedded IDP and reloads setupRequired from persistent state.
|
||||
func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error {
|
||||
if m.embeddedIdpManager == nil {
|
||||
return errors.New("embedded IDP is not enabled")
|
||||
}
|
||||
|
||||
var deleteErr error
|
||||
if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil {
|
||||
if isNotFoundError(err) {
|
||||
log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID)
|
||||
} else {
|
||||
deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.loadSetupRequired(ctx); err != nil {
|
||||
reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err)
|
||||
if deleteErr != nil {
|
||||
return errors.Join(deleteErr, reloadErr)
|
||||
}
|
||||
return reloadErr
|
||||
}
|
||||
|
||||
if deleteErr != nil {
|
||||
return deleteErr
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("rolled back setup for user %s", userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return true
|
||||
}
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return s.Type() == status.NotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
|
||||
numAccounts, err := m.store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,16 +10,19 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type mockIdP struct {
|
||||
mu sync.Mutex
|
||||
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
users map[string][]*idp.UserData
|
||||
mu sync.Mutex
|
||||
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
deleteUserFunc func(ctx context.Context, userID string) error
|
||||
users map[string][]*idp.UserData
|
||||
getAllAccountsErr error
|
||||
}
|
||||
|
||||
@@ -30,6 +33,13 @@ func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, n
|
||||
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
|
||||
}
|
||||
|
||||
func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error {
|
||||
if m.deleteUserFunc != nil {
|
||||
return m.deleteUserFunc(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -223,6 +233,77 @@ func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
|
||||
assert.False(t, required)
|
||||
}
|
||||
|
||||
func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "management status not found",
|
||||
err: status.NewUserNotFoundError("owner-id"),
|
||||
},
|
||||
{
|
||||
name: "dex storage not found",
|
||||
err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
idpMock := &mockIdP{
|
||||
deleteUserFunc: func(_ context.Context, userID string) error {
|
||||
assert.Equal(t, "owner-id", userID)
|
||||
return tt.err
|
||||
},
|
||||
}
|
||||
mgr := newTestManager(idpMock, &mockStore{})
|
||||
mgr.setupRequired = false
|
||||
|
||||
err := mgr.RollbackSetup(context.Background(), "owner-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
required, err := mgr.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, required, "setup should be required when no accounts or local users remain")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) {
|
||||
idpMock := &mockIdP{
|
||||
deleteUserFunc: func(_ context.Context, _ string) error {
|
||||
return status.NewUserNotFoundError("owner-id")
|
||||
},
|
||||
}
|
||||
mgr := newTestManager(idpMock, &mockStore{accountsCount: 1})
|
||||
mgr.setupRequired = true
|
||||
|
||||
err := mgr.RollbackSetup(context.Background(), "owner-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
required, err := mgr.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.False(t, required, "setup should not be required while an account still exists")
|
||||
}
|
||||
|
||||
func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) {
|
||||
idpMock := &mockIdP{
|
||||
deleteUserFunc: func(_ context.Context, _ string) error {
|
||||
return errors.New("idp unavailable")
|
||||
},
|
||||
}
|
||||
mgr := newTestManager(idpMock, &mockStore{})
|
||||
mgr.setupRequired = false
|
||||
|
||||
err := mgr.RollbackSetup(context.Background(), "owner-id")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "idp unavailable")
|
||||
|
||||
required, err := mgr.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, required, "setup state should be reloaded even when user deletion fails")
|
||||
}
|
||||
|
||||
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
|
||||
manager := &DefaultManager{setupRequired: true}
|
||||
|
||||
|
||||
216
management/server/instance/setup_service.go
Normal file
216
management/server/instance/setup_service.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
setupPATTokenName = "setup-token"
|
||||
|
||||
// SetupPATEnabledEnvKey enables setup-time Personal Access Token creation.
|
||||
SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED"
|
||||
|
||||
setupPATDefaultExpireDays = 1
|
||||
)
|
||||
|
||||
// SetupOptions controls optional work performed during initial instance setup.
|
||||
type SetupOptions struct {
|
||||
// CreatePAT requests creation of a setup Personal Access Token. It is honored
|
||||
// only when SetupPATEnabledEnvKey is set to "true".
|
||||
CreatePAT bool
|
||||
// PATExpireInDays defaults to 1 day when CreatePAT is requested and setup PAT
|
||||
// creation is enabled.
|
||||
PATExpireInDays *int
|
||||
}
|
||||
|
||||
// SetupResult contains resources created during initial instance setup.
|
||||
type SetupResult struct {
|
||||
User *idp.UserData
|
||||
PATPlainToken string
|
||||
}
|
||||
|
||||
// SetupService orchestrates the initial setup use case across the instance and
|
||||
// account bounded contexts and owns the compensation logic when a later step
|
||||
// fails.
|
||||
type SetupService struct {
|
||||
instanceManager Manager
|
||||
accountManager account.Manager
|
||||
setupPATEnabled bool
|
||||
}
|
||||
|
||||
// NewSetupService creates a setup use-case service.
|
||||
func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService {
|
||||
return &SetupService{
|
||||
instanceManager: instanceManager,
|
||||
accountManager: accountManager,
|
||||
setupPATEnabled: os.Getenv(SetupPATEnabledEnvKey) == "true",
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOptions, error) {
|
||||
if !opts.CreatePAT {
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
if !setupPATEnabled {
|
||||
opts.CreatePAT = false
|
||||
opts.PATExpireInDays = nil
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
if opts.PATExpireInDays == nil {
|
||||
defaultExpireInDays := setupPATDefaultExpireDays
|
||||
opts.PATExpireInDays = &defaultExpireInDays
|
||||
}
|
||||
|
||||
if *opts.PATExpireInDays < account.PATMinExpireDays || *opts.PATExpireInDays > account.PATMaxExpireDays {
|
||||
return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// SetupOwner creates the initial owner user and, when requested and enabled by
|
||||
// SetupPATEnabledEnvKey, provisions the account and a setup Personal Access
|
||||
// Token. If account or PAT provisioning fails, created resources are rolled
|
||||
// back so setup can be retried. If account rollback fails, user rollback is
|
||||
// skipped to avoid leaving an account without its owner user.
|
||||
func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) {
|
||||
opts, err := normalizeSetupOptions(opts, m.setupPATEnabled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.CreatePAT && m.accountManager == nil {
|
||||
return nil, fmt.Errorf("account manager is required to create setup PAT")
|
||||
}
|
||||
|
||||
userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SetupResult{User: userData}
|
||||
if !opts.CreatePAT {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: userData.ID,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
}
|
||||
|
||||
accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create account for setup user: %w", err)
|
||||
if rollbackErr := m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, ""); rollbackErr != nil {
|
||||
return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, *opts.PATExpireInDays)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create setup PAT: %w", err)
|
||||
if rollbackErr := m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID); rollbackErr != nil {
|
||||
return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result.PATPlainToken = pat.PlainToken
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) error {
|
||||
if accountID == "" {
|
||||
resolvedAccountID, err := m.lookupSetupAccountIDForRollback(ctx, userID)
|
||||
if err != nil {
|
||||
rollbackErr := fmt.Errorf("resolve setup account for rollback: %w", err)
|
||||
log.WithContext(ctx).Errorf("failed to resolve setup account for user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr)
|
||||
return rollbackErr
|
||||
}
|
||||
accountID = resolvedAccountID
|
||||
}
|
||||
|
||||
if accountID != "" {
|
||||
if err := m.rollbackSetupAccount(ctx, accountID); err != nil {
|
||||
rollbackErr := fmt.Errorf("roll back setup account %s: %w", accountID, err)
|
||||
log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, rollbackErr)
|
||||
return rollbackErr
|
||||
}
|
||||
log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr)
|
||||
}
|
||||
|
||||
if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil {
|
||||
rollbackErr := fmt.Errorf("roll back setup user %s: %w", userID, err)
|
||||
log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr)
|
||||
return rollbackErr
|
||||
}
|
||||
log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SetupService) lookupSetupAccountIDForRollback(ctx context.Context, userID string) (string, error) {
|
||||
if m.accountManager == nil {
|
||||
return "", fmt.Errorf("account manager is required to resolve setup account")
|
||||
}
|
||||
|
||||
accountStore := m.accountManager.GetStore()
|
||||
if accountStore == nil {
|
||||
return "", fmt.Errorf("account store is unavailable")
|
||||
}
|
||||
|
||||
accountID, err := accountStore.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("get setup account ID for rollback: %w", err)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// rollbackSetupAccount removes only the setup-created account data from the
|
||||
// store. It intentionally avoids accountManager.DeleteAccount because the normal
|
||||
// account deletion path also deletes users from the IdP; embedded IdP cleanup is
|
||||
// owned by instanceManager.RollbackSetup.
|
||||
func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error {
|
||||
if m.accountManager == nil {
|
||||
return fmt.Errorf("account manager is required to roll back setup account")
|
||||
}
|
||||
|
||||
accountStore := m.accountManager.GetStore()
|
||||
if accountStore == nil {
|
||||
return fmt.Errorf("account store is unavailable")
|
||||
}
|
||||
|
||||
account, err := accountStore.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("get setup account for rollback: %w", err)
|
||||
}
|
||||
|
||||
if err := accountStore.DeleteAccount(ctx, account); err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("delete setup account for rollback: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
318
management/server/instance/setup_service_test.go
Normal file
318
management/server/instance/setup_service_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbstore "github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type setupInstanceManagerMock struct {
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
rollbackSetupFn func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
if m.createOwnerUserFn != nil {
|
||||
return m.createOwnerUserFn(ctx, email, password, name)
|
||||
}
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error {
|
||||
if m.rollbackSetupFn != nil {
|
||||
return m.rollbackSetupFn(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) {
|
||||
return &VersionInfo{}, nil
|
||||
}
|
||||
|
||||
var _ Manager = (*setupInstanceManagerMock)(nil)
|
||||
|
||||
func intPtr(v int) *int {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestSetupOwner_PATFeatureDisabled_IgnoresCreatePAT(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "false")
|
||||
|
||||
createCalls := 0
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
|
||||
createCalls++
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "owner-id", result.User.ID)
|
||||
assert.Empty(t, result.PATPlainToken)
|
||||
assert.Equal(t, 1, createCalls)
|
||||
}
|
||||
|
||||
func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
createCalled := false
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
|
||||
createCalled = true
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "owner-id", userAuth.UserId)
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "owner-id", initiatorUserID)
|
||||
assert.Equal(t, "owner-id", targetUserID)
|
||||
assert.Equal(t, setupPATTokenName, tokenName)
|
||||
assert.Equal(t, setupPATDefaultExpireDays, expiresIn)
|
||||
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.True(t, createCalled)
|
||||
assert.Equal(t, "nbp_plain", result.PATPlainToken)
|
||||
}
|
||||
|
||||
func TestSetupOwner_PATFeatureEnabled_MissingAccountManagerFailsBeforeCreateUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
createCalled := false
|
||||
rollbackCalled := false
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
|
||||
createCalled = true
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
rollbackSetupFn: func(_ context.Context, _ string) error {
|
||||
rollbackCalled = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "account manager is required")
|
||||
assert.False(t, createCalled)
|
||||
assert.False(t, rollbackCalled)
|
||||
}
|
||||
|
||||
func TestSetupOwner_AccountProvisioningFails_RollsBackSideEffectAccountAndUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("acc-1", nil)
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
|
||||
|
||||
rolledBackFor := ""
|
||||
rollbackCalls := 0
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rollbackCalls++
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "owner-id", userAuth.UserId)
|
||||
return "", errors.New("metadata update failed")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create account for setup user")
|
||||
assert.Equal(t, "owner-id", rolledBackFor)
|
||||
assert.Equal(t, 1, rollbackCalls)
|
||||
}
|
||||
|
||||
func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
|
||||
|
||||
rollbackCalls := 0
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rollbackCalls++
|
||||
assert.Equal(t, "owner-id", userID)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "owner-id", userAuth.UserId)
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "owner-id", initiatorUserID)
|
||||
assert.Equal(t, "owner-id", targetUserID)
|
||||
assert.Equal(t, setupPATTokenName, tokenName)
|
||||
assert.Equal(t, 30, expiresIn)
|
||||
return nil, status.Errorf(status.Internal, "token store unavailable")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create setup PAT")
|
||||
assert.Equal(t, 1, rollbackCalls)
|
||||
}
|
||||
|
||||
func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1"))
|
||||
|
||||
rolledBackFor := ""
|
||||
rollbackCalls := 0
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rollbackCalls++
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
return nil, errors.New("token failure")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create setup PAT")
|
||||
assert.Equal(t, "owner-id", rolledBackFor)
|
||||
assert.Equal(t, 1, rollbackCalls)
|
||||
}
|
||||
|
||||
func TestSetupOwner_CreatePATFails_AccountRollbackFailureStopsBeforeUserRollback(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed"))
|
||||
|
||||
rollbackCalls := 0
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rollbackCalls++
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
return nil, errors.New("token failure")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create setup PAT")
|
||||
assert.Contains(t, err.Error(), "failed to roll back setup resources")
|
||||
assert.Equal(t, 0, rollbackCalls)
|
||||
}
|
||||
@@ -128,8 +128,8 @@ type MockAccountManager struct {
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
|
||||
GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error)
|
||||
@@ -200,15 +200,15 @@ func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userI
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
if am.UpdateAccountPeersFunc != nil {
|
||||
am.UpdateAccountPeersFunc(ctx, accountID)
|
||||
am.UpdateAccountPeersFunc(ctx, accountID, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
if am.BufferUpdateAccountPeersFunc != nil {
|
||||
am.BufferUpdateAccountPeersFunc(ctx, accountID)
|
||||
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
@@ -133,7 +133,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -176,7 +176,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -177,7 +178,7 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
@@ -270,7 +270,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
}
|
||||
}()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
@@ -352,7 +352,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -119,7 +120,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
|
||||
|
||||
return router, nil
|
||||
}
|
||||
@@ -183,7 +184,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
|
||||
|
||||
return router, nil
|
||||
}
|
||||
@@ -217,7 +218,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
|
||||
event()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -818,6 +818,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
|
||||
}
|
||||
if newPeer.Status != nil && newPeer.Status.RequiresApproval {
|
||||
opEvent.Meta["pending_approval"] = true
|
||||
}
|
||||
|
||||
if !temporary {
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
@@ -1221,12 +1224,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
|
||||
// UpdateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAccountPeer updates a single peer that belongs to an account.
|
||||
|
||||
@@ -975,7 +975,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.UpdateAccountPeers(ctx, account.Id)
|
||||
manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{})
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
@@ -1033,7 +1033,7 @@ func testUpdateAccountPeers(t *testing.T) {
|
||||
peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
manager.UpdateAccountPeers(ctx, account.Id)
|
||||
manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{})
|
||||
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
|
||||
@@ -96,7 +96,11 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
policyOp := types.UpdateOperationCreate
|
||||
if isUpdate {
|
||||
policyOp = types.UpdateOperationUpdate
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp})
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
@@ -139,7 +143,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -17,19 +17,48 @@ type PeerNetworkRangeCheck struct {
|
||||
|
||||
var _ Check = (*PeerNetworkRangeCheck)(nil)
|
||||
|
||||
// prefixContains reports whether outer fully contains inner (equal counts as contained).
|
||||
// Requires the same address family, that outer is no more specific than inner (its
|
||||
// netmask is shorter or equal), and that inner's network address falls inside outer.
|
||||
// This is stricter than netip.Prefix.Contains(Addr) — a peer's /24 NIC will not match a
|
||||
// configured /32 rule, since the rule covers a single host but the NIC describes a whole
|
||||
// subnet whose host bits are unknown.
|
||||
func prefixContains(outer, inner netip.Prefix) bool {
|
||||
outer = outer.Masked()
|
||||
inner = inner.Masked()
|
||||
return outer.Bits() <= inner.Bits() &&
|
||||
outer.Addr().BitLen() == inner.Addr().BitLen() && // same family
|
||||
outer.Contains(inner.Addr())
|
||||
}
|
||||
|
||||
// Check evaluates configured ranges against the peer's local network interface prefixes
|
||||
// and its public connection IP (as a /32 or /128). A configured range matches when it
|
||||
// fully contains one of those prefixes, so operators can target both private subnets
|
||||
// and public CIDRs (e.g. 1.0.0.0/24, 2.2.2.2/32). Including the connection IP is what
|
||||
// lets a public-range posture check work — peer.Meta.NetworkAddresses only carries
|
||||
// local NIC addresses.
|
||||
func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
if len(peer.Meta.NetworkAddresses) == 0 {
|
||||
peerPrefixes := make([]netip.Prefix, 0, len(peer.Meta.NetworkAddresses)+1)
|
||||
for _, peerNetAddr := range peer.Meta.NetworkAddresses {
|
||||
peerPrefixes = append(peerPrefixes, peerNetAddr.NetIP)
|
||||
}
|
||||
// Unmap collapses 4-in-6 forms (::ffff:a.b.c.d) so an IPv4 range matches.
|
||||
if connIP := peer.Location.ConnectionIP; len(connIP) > 0 {
|
||||
if addr, ok := netip.AddrFromSlice(connIP); ok {
|
||||
addr = addr.Unmap()
|
||||
peerPrefixes = append(peerPrefixes, netip.PrefixFrom(addr, addr.BitLen()))
|
||||
}
|
||||
}
|
||||
|
||||
if len(peerPrefixes) == 0 {
|
||||
return false, fmt.Errorf("peer's does not contain peer network range addresses")
|
||||
}
|
||||
|
||||
maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges))
|
||||
for _, prefix := range p.Ranges {
|
||||
maskedPrefixes = append(maskedPrefixes, prefix.Masked())
|
||||
}
|
||||
|
||||
for _, peerNetAddr := range peer.Meta.NetworkAddresses {
|
||||
peerMaskedPrefix := peerNetAddr.NetIP.Masked()
|
||||
if slices.Contains(maskedPrefixes, peerMaskedPrefix) {
|
||||
for _, peerPrefix := range peerPrefixes {
|
||||
for _, rangePrefix := range p.Ranges {
|
||||
if !prefixContains(rangePrefix, peerPrefix) {
|
||||
continue
|
||||
}
|
||||
switch p.Action {
|
||||
case CheckActionDeny:
|
||||
return false, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -134,6 +135,205 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
|
||||
wantErr: true,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "Peer connection IP matches the denied /32",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("109.41.115.194/32"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{
|
||||
{NetIP: netip.MustParsePrefix("192.168.0.123/24")},
|
||||
},
|
||||
},
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "Peer connection IP does not match the denied /32",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("109.41.115.194/32"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{
|
||||
{NetIP: netip.MustParsePrefix("192.168.0.123/24")},
|
||||
},
|
||||
},
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "Peer connection IP matches the allowed /32 with no NetworkAddresses",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("109.41.115.194/32"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 connection IP matches the denied /128",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::1")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 connection IP does not match the denied /128",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::2")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4-mapped IPv6 connection IP matches IPv4 /32",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("109.41.115.194/32"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("::ffff:109.41.115.194")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "Connection IP falls inside an allowed /24 range",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("1.0.0.0/24"),
|
||||
netip.MustParsePrefix("2.2.2.2/32"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.55")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "Connection IP falls inside an allowed /23 range",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("3.0.0.0/23"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("3.0.1.200")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "Connection IP outside the allowed /24 range",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("1.0.0.0/24"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.1.5")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "Connection IP inside a denied /24 range",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("1.0.0.0/24"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.7")},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "Local NIC /24 does not match a /32 rule even if host bit lines up",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.5/32"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{
|
||||
{NetIP: netip.MustParsePrefix("192.168.0.5/24")},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "Local NIC address inside an allowed /16 range",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{
|
||||
{NetIP: netip.MustParsePrefix("192.168.5.7/24")},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NetworkAddresses and empty ConnectionIP still errors",
|
||||
check: PeerNetworkRangeCheck{
|
||||
Action: CheckActionDeny,
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("109.41.115.194/32"),
|
||||
},
|
||||
},
|
||||
peer: nbpeer.Peer{},
|
||||
wantErr: true,
|
||||
isValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -76,7 +77,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
postureOp := types.UpdateOperationCreate
|
||||
if isUpdate {
|
||||
postureOp = types.UpdateOperationUpdate
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp})
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
|
||||
@@ -191,7 +191,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
return newRoute, nil
|
||||
@@ -245,7 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
type AccountManagerMetrics struct {
|
||||
ctx context.Context
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
updateAccountPeersCounter metric.Int64Counter
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
@@ -48,6 +50,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updateAccountPeersCounter, err := meter.Int64Counter("management.account.update.account.peers.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of account peers updates triggered, labeled by resource and operation"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of updates with new meta data from the peers"))
|
||||
@@ -59,6 +68,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||
updateAccountPeersCounter: updateAccountPeersCounter,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
}, nil
|
||||
@@ -80,6 +90,16 @@ func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||
}
|
||||
|
||||
// CountUpdateAccountPeersTriggered increments the counter for account peers updates with resource and operation labels.
|
||||
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource, operation string) {
|
||||
metrics.updateAccountPeersCounter.Add(metrics.ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("resource", resource),
|
||||
attribute.String("operation", operation),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// CountPeerMetUpdate counts the number of peer meta updates
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
|
||||
37
management/server/types/update_reason.go
Normal file
37
management/server/types/update_reason.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package types
|
||||
|
||||
// UpdateReason describes why an account peers update was triggered.
|
||||
type UpdateReason struct {
|
||||
Resource UpdateResource
|
||||
Operation UpdateOperation
|
||||
}
|
||||
|
||||
// UpdateResource represents the kind of resource that triggered an account peers update.
|
||||
type UpdateResource string
|
||||
|
||||
const (
|
||||
UpdateResourceAccountSettings UpdateResource = "account_settings"
|
||||
UpdateResourceDNSSettings UpdateResource = "dns_settings"
|
||||
UpdateResourceGroup UpdateResource = "group"
|
||||
UpdateResourceNameServerGroup UpdateResource = "nameserver_group"
|
||||
UpdateResourcePolicy UpdateResource = "policy"
|
||||
UpdateResourcePostureCheck UpdateResource = "posture_check"
|
||||
UpdateResourceRoute UpdateResource = "route"
|
||||
UpdateResourceUser UpdateResource = "user"
|
||||
UpdateResourcePeer UpdateResource = "peer"
|
||||
UpdateResourceNetwork UpdateResource = "network"
|
||||
UpdateResourceNetworkResource UpdateResource = "network_resource"
|
||||
UpdateResourceNetworkRouter UpdateResource = "network_router"
|
||||
UpdateResourceService UpdateResource = "service"
|
||||
UpdateResourceZone UpdateResource = "zone"
|
||||
UpdateResourceZoneRecord UpdateResource = "zone_record"
|
||||
)
|
||||
|
||||
// UpdateOperation represents the kind of change that triggered the update.
|
||||
type UpdateOperation string
|
||||
|
||||
const (
|
||||
UpdateOperationCreate UpdateOperation = "create"
|
||||
UpdateOperationUpdate UpdateOperation = "update"
|
||||
UpdateOperationDelete UpdateOperation = "delete"
|
||||
)
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -395,8 +396,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
|
||||
}
|
||||
|
||||
if expiresIn < 1 || expiresIn > 365 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
||||
if expiresIn < account.PATMinExpireDays || expiresIn > account.PATMaxExpireDays {
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays)
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create)
|
||||
@@ -675,7 +676,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return updatedUsersInfo, globalErr
|
||||
|
||||
@@ -2,7 +2,12 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -99,6 +104,27 @@ var debugStopCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugCaptureCmd = &cobra.Command{
|
||||
Use: "capture <account-id> [filter expression]",
|
||||
Short: "Capture packets on a client's WireGuard interface",
|
||||
Long: `Captures decrypted packets flowing through a client's WireGuard interface.
|
||||
|
||||
Default output is human-readable text. Use --pcap or --output for pcap binary.
|
||||
Filter arguments after the account ID use BPF-like syntax.
|
||||
|
||||
Examples:
|
||||
netbird-proxy debug capture <account-id>
|
||||
netbird-proxy debug capture <account-id> --duration 1m host 10.0.0.1
|
||||
netbird-proxy debug capture <account-id> host 10.0.0.1 and tcp port 443
|
||||
netbird-proxy debug capture <account-id> not port 22
|
||||
netbird-proxy debug capture <account-id> -o capture.pcap
|
||||
netbird-proxy debug capture <account-id> --pcap | tcpdump -r - -n
|
||||
netbird-proxy debug capture <account-id> --pcap | tshark -r -`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: runDebugCapture,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
|
||||
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
|
||||
@@ -110,6 +136,12 @@ func init() {
|
||||
|
||||
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
|
||||
|
||||
debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)")
|
||||
debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
|
||||
debugCaptureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length (text mode)")
|
||||
debugCaptureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (text mode)")
|
||||
debugCaptureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
|
||||
|
||||
debugCmd.AddCommand(debugHealthCmd)
|
||||
debugCmd.AddCommand(debugClientsCmd)
|
||||
debugCmd.AddCommand(debugStatusCmd)
|
||||
@@ -119,6 +151,7 @@ func init() {
|
||||
debugCmd.AddCommand(debugLogCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
debugCmd.AddCommand(debugCaptureCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
}
|
||||
@@ -171,3 +204,84 @@ func runDebugStart(cmd *cobra.Command, args []string) error {
|
||||
func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
||||
duration, _ := cmd.Flags().GetDuration("duration")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
verbose, _ := cmd.Flags().GetBool("verbose")
|
||||
ascii, _ := cmd.Flags().GetBool("ascii")
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
|
||||
// Default to text. Use pcap when --pcap is set or --output is given.
|
||||
wantText := !forcePcap && outPath == ""
|
||||
|
||||
var filterExpr string
|
||||
if len(args) > 1 {
|
||||
filterExpr = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
out, cleanup, err := captureOutputWriter(cmd, outPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if wantText {
|
||||
cmd.PrintErrln("Capturing packets... Press Ctrl+C to stop.")
|
||||
} else {
|
||||
cmd.PrintErrln("Capturing packets (pcap)... Press Ctrl+C to stop.")
|
||||
}
|
||||
|
||||
var durationStr string
|
||||
if duration > 0 {
|
||||
durationStr = duration.String()
|
||||
}
|
||||
|
||||
err = getDebugClient(cmd).Capture(ctx, debug.CaptureOptions{
|
||||
AccountID: args[0],
|
||||
Duration: durationStr,
|
||||
FilterExpr: filterExpr,
|
||||
Text: wantText,
|
||||
Verbose: verbose,
|
||||
ASCII: ascii,
|
||||
Output: out,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PrintErrln("\nCapture finished.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// captureOutputWriter returns the writer and cleanup function for capture output.
|
||||
func captureOutputWriter(cmd *cobra.Command, outPath string) (out *os.File, cleanup func(), err error) {
|
||||
if outPath != "" {
|
||||
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create output file: %w", err)
|
||||
}
|
||||
tmpPath := f.Name()
|
||||
return f, func() {
|
||||
if err := f.Close(); err != nil {
|
||||
cmd.PrintErrf("close output file: %v\n", err)
|
||||
}
|
||||
if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 {
|
||||
if err := os.Rename(tmpPath, outPath); err != nil {
|
||||
cmd.PrintErrf("rename output file: %v\n", err)
|
||||
} else {
|
||||
cmd.PrintErrf("Wrote %s\n", outPath)
|
||||
}
|
||||
} else {
|
||||
os.Remove(tmpPath)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
return os.Stdout, func() {
|
||||
// no cleanup needed for stdout
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -310,6 +310,76 @@ func (c *Client) printError(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
// CaptureOptions configures a capture request.
|
||||
type CaptureOptions struct {
|
||||
AccountID string
|
||||
Duration string
|
||||
FilterExpr string
|
||||
Text bool
|
||||
Verbose bool
|
||||
ASCII bool
|
||||
Output io.Writer
|
||||
}
|
||||
|
||||
// Capture streams a packet capture from the debug endpoint. The response body
|
||||
// (pcap or text) is written directly to opts.Output until the server closes the
|
||||
// connection or the context is cancelled.
|
||||
func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error {
|
||||
if opts.AccountID == "" {
|
||||
return fmt.Errorf("account ID is required")
|
||||
}
|
||||
if opts.Output == nil {
|
||||
return fmt.Errorf("output writer is required")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
if opts.Duration != "" {
|
||||
params.Set("duration", opts.Duration)
|
||||
}
|
||||
if opts.FilterExpr != "" {
|
||||
params.Set("filter", opts.FilterExpr)
|
||||
}
|
||||
if opts.Text {
|
||||
params.Set("format", "text")
|
||||
}
|
||||
if opts.Verbose {
|
||||
params.Set("verbose", "true")
|
||||
}
|
||||
if opts.ASCII {
|
||||
params.Set("ascii", "true")
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID))
|
||||
if len(params) > 0 {
|
||||
path += "?" + params.Encode()
|
||||
}
|
||||
|
||||
fullURL := c.baseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
// Use a separate client without timeout since captures stream for their full duration.
|
||||
httpClient := &http.Client{}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
_, err = io.Copy(opts.Output, resp.Body)
|
||||
if err != nil && ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
|
||||
@@ -174,6 +174,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat
|
||||
h.handleClientStart(w, r, accountID)
|
||||
case "stop":
|
||||
h.handleClientStop(w, r, accountID)
|
||||
case "capture":
|
||||
h.handleCapture(w, r, accountID)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -632,6 +634,81 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
||||
})
|
||||
}
|
||||
|
||||
const maxCaptureDuration = 30 * time.Minute
|
||||
|
||||
// handleCapture streams a pcap or text packet capture for the given client.
|
||||
//
|
||||
// Query params:
|
||||
//
|
||||
// duration: capture duration (0 or absent = max, capped at 30m)
|
||||
// format: "text" for human-readable output (default: pcap)
|
||||
// filter: BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443")
|
||||
func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "client not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
duration := maxCaptureDuration
|
||||
if durationStr := r.URL.Query().Get("duration"); durationStr != "" {
|
||||
d, err := time.ParseDuration(durationStr)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d < 0 {
|
||||
http.Error(w, "duration must not be negative", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d > 0 {
|
||||
duration = min(d, maxCaptureDuration)
|
||||
}
|
||||
}
|
||||
|
||||
filter := r.URL.Query().Get("filter")
|
||||
wantText := r.URL.Query().Get("format") == "text"
|
||||
verbose := r.URL.Query().Get("verbose") == "true"
|
||||
ascii := r.URL.Query().Get("ascii") == "true"
|
||||
|
||||
opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii}
|
||||
if wantText {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
opts.TextOutput = w
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap")
|
||||
w.Header().Set("Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID))
|
||||
opts.Output = w
|
||||
}
|
||||
|
||||
cs, err := client.StartCapture(opts)
|
||||
if err != nil {
|
||||
http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer cs.Stop()
|
||||
|
||||
// Flush headers after setup succeeds so errors above can still set status codes.
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
timer := time.NewTimer(duration)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
cs.Stop()
|
||||
|
||||
stats := cs.Stats()
|
||||
h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped",
|
||||
accountID, stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}
|
||||
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
|
||||
if !wantJSON {
|
||||
http.Redirect(w, r, "/debug", http.StatusSeeOther)
|
||||
|
||||
@@ -246,27 +246,23 @@ func (c *GrpcClient) handleJobStream(
|
||||
for {
|
||||
jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
log.Debugf("job stream context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
}
|
||||
if s, ok := gstatus.FromError(err); ok {
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
c.notifyDisconnected(err)
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("job stream context has been canceled, this usually indicates shutdown")
|
||||
return err
|
||||
case codes.Unimplemented:
|
||||
log.Warn("Job feature is not supported by the current management server version. " +
|
||||
"Please update the management service to use this feature.")
|
||||
return nil
|
||||
default:
|
||||
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// non-gRPC error
|
||||
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if jobReq == nil || len(jobReq.ID) == 0 {
|
||||
@@ -381,22 +377,15 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
|
||||
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
if s, ok := gstatus.FromError(err); ok {
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
default:
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// non-gRPC error
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
if ctx.Err() != nil {
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
}
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
}
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1687,15 +1687,18 @@ components:
|
||||
- locations
|
||||
- action
|
||||
PeerNetworkRangeCheck:
|
||||
description: Posture check for allow or deny access based on peer local network addresses
|
||||
description: |
|
||||
Posture check for allow or deny access based on the peer's IP addresses. A range matches when it
|
||||
contains any of the peer's local network interface IPs or its public connection (NAT egress) IP,
|
||||
so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128.
|
||||
type: object
|
||||
properties:
|
||||
ranges:
|
||||
description: List of peer network ranges in CIDR notation
|
||||
description: List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: [ "192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56" ]
|
||||
example: [ "192.168.1.0/24", "10.0.0.0/8", "1.0.0.0/24", "2.2.2.2/32", "2001:db8:1234:1a00::/56" ]
|
||||
action:
|
||||
description: Action to take upon policy match
|
||||
type: string
|
||||
@@ -3426,6 +3429,17 @@ components:
|
||||
description: Display name for the admin user (defaults to email if not provided)
|
||||
type: string
|
||||
example: Admin User
|
||||
create_pat:
|
||||
description: If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled.
|
||||
type: boolean
|
||||
example: true
|
||||
pat_expire_in:
|
||||
description: Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted.
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: 365
|
||||
default: 1
|
||||
example: 30
|
||||
required:
|
||||
- email
|
||||
- password
|
||||
@@ -3442,6 +3456,12 @@ components:
|
||||
description: Email address of the created user
|
||||
type: string
|
||||
example: admin@example.com
|
||||
personal_access_token:
|
||||
description: Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server.
|
||||
type: string
|
||||
format: password
|
||||
readOnly: true
|
||||
example: nbp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
required:
|
||||
- user_id
|
||||
- email
|
||||
@@ -4980,7 +5000,10 @@ paths:
|
||||
/api/setup:
|
||||
post:
|
||||
summary: Setup Instance
|
||||
description: Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled).
|
||||
description: |
|
||||
Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled).
|
||||
|
||||
When the management server is started with `NB_SETUP_PAT_ENABLED=true` and the request includes `create_pat: true`, the endpoint also provisions the NetBird account for the new owner user and returns the plain text Personal Access Token in `personal_access_token`. The optional `pat_expire_in` value applies only when `create_pat` is true and defaults to 1 day when omitted. If a post-user step fails, setup-created resources are rolled back when safe; if account cleanup fails, the owner user is left in place to avoid leaving an account without its admin user.
|
||||
tags: [ Instance ]
|
||||
security: [ ]
|
||||
requestBody:
|
||||
@@ -4993,6 +5016,12 @@ paths:
|
||||
responses:
|
||||
'200':
|
||||
description: Setup completed successfully
|
||||
headers:
|
||||
Cache-Control:
|
||||
description: Always set to no-store because the response may contain a one-time plain text Personal Access Token.
|
||||
schema:
|
||||
type: string
|
||||
example: no-store
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
|
||||
@@ -1626,7 +1626,7 @@ type Checks struct {
|
||||
// OsVersionCheck Posture check for the version of operating system
|
||||
OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"`
|
||||
|
||||
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
|
||||
// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128.
|
||||
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
|
||||
|
||||
// ProcessCheck Posture Check for binaries exist and are running in the peer’s system
|
||||
@@ -3312,12 +3312,12 @@ type PeerMinimum struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
|
||||
// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128.
|
||||
type PeerNetworkRangeCheck struct {
|
||||
// Action Action to take upon policy match
|
||||
Action PeerNetworkRangeCheckAction `json:"action"`
|
||||
|
||||
// Ranges List of peer network ranges in CIDR notation
|
||||
// Ranges List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP
|
||||
Ranges []string `json:"ranges"`
|
||||
}
|
||||
|
||||
@@ -4297,6 +4297,9 @@ type SetupKeyRequest struct {
|
||||
|
||||
// SetupRequest Request to set up the initial admin user
|
||||
type SetupRequest struct {
|
||||
// CreatePat If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled.
|
||||
CreatePat *bool `json:"create_pat,omitempty"`
|
||||
|
||||
// Email Email address for the admin user
|
||||
Email string `json:"email"`
|
||||
|
||||
@@ -4305,6 +4308,9 @@ type SetupRequest struct {
|
||||
|
||||
// Password Password for the admin user (minimum 8 characters)
|
||||
Password string `json:"password"`
|
||||
|
||||
// PatExpireIn Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted.
|
||||
PatExpireIn *int `json:"pat_expire_in,omitempty"`
|
||||
}
|
||||
|
||||
// SetupResponse Response after successful instance setup
|
||||
@@ -4312,6 +4318,9 @@ type SetupResponse struct {
|
||||
// Email Email address of the created user
|
||||
Email string `json:"email"`
|
||||
|
||||
// PersonalAccessToken Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server.
|
||||
PersonalAccessToken *string `json:"personal_access_token,omitempty"`
|
||||
|
||||
// UserId The ID of the created user
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
@@ -2,8 +2,12 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -146,6 +150,7 @@ func (cc *connContainer) close() {
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
connectionURL string
|
||||
serverIP netip.Addr
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID messages.PeerID
|
||||
|
||||
@@ -170,13 +175,22 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
// is called.
|
||||
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
return NewClientWithServerIP(serverURL, netip.Addr{}, authTokenStore, peerID, mtu)
|
||||
}
|
||||
|
||||
// NewClientWithServerIP creates a new client for the relay server with a known server IP. serverIP, when valid, is
|
||||
// dialed directly first; the FQDN is only attempted if the IP-based dial fails. TLS verification still uses the
|
||||
// FQDN from serverURL via SNI.
|
||||
func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
hashedID := messages.HashID(peerID)
|
||||
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||
|
||||
c := &Client{
|
||||
log: relayLog,
|
||||
connectionURL: serverURL,
|
||||
serverIP: serverIP,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
mtu: mtu,
|
||||
@@ -304,6 +318,23 @@ func (c *Client) ServerInstanceURL() (string, error) {
|
||||
return c.instanceURL.String(), nil
|
||||
}
|
||||
|
||||
// ConnectedIP returns the IP address of the live relay-server connection,
|
||||
// extracted from the underlying socket's RemoteAddr. Zero value if not
|
||||
// connected or if the address is not an IP literal.
|
||||
func (c *Client) ConnectedIP() netip.Addr {
|
||||
c.mu.Lock()
|
||||
conn := c.relayConn
|
||||
c.mu.Unlock()
|
||||
if conn == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
addr := conn.RemoteAddr()
|
||||
if addr == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return extractIPLiteral(addr.String())
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
||||
c.listenerMutex.Lock()
|
||||
@@ -332,10 +363,23 @@ func (c *Client) Close() error {
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
dialers := c.getDialers()
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
conn, err := rd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial via FQDN: %w", err)
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
@@ -351,6 +395,52 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
return instanceURL, nil
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
}
|
||||
|
||||
c.log.Debugf("dialing via server IP %s (SNI=%s)", c.serverIP, serverName)
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
// substituteHost replaces the host portion of a rel/rels URL with ip,
|
||||
// preserving the scheme and port. Returns the rewritten URL and the
|
||||
// original host to use as the TLS ServerName, or empty if the original
|
||||
// host is itself an IP literal (SNI requires a DNS name).
|
||||
func substituteHost(serverURL string, ip netip.Addr) (string, string, error) {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("parse %q: %w", serverURL, err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", fmt.Errorf("invalid relay URL %q", serverURL)
|
||||
}
|
||||
if !ip.IsValid() {
|
||||
return "", "", errors.New("invalid server IP")
|
||||
}
|
||||
origHost := u.Hostname()
|
||||
if _, err := netip.ParseAddr(origHost); err == nil {
|
||||
origHost = ""
|
||||
}
|
||||
ip = ip.Unmap()
|
||||
newHost := ip.String()
|
||||
if ip.Is6() {
|
||||
newHost = "[" + newHost + "]"
|
||||
}
|
||||
if port := u.Port(); port != "" {
|
||||
u.Host = newHost + ":" + port
|
||||
} else {
|
||||
u.Host = newHost
|
||||
}
|
||||
return u.String(), origHost, nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
@@ -716,3 +806,21 @@ func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
|
||||
}
|
||||
c.stateSubscription.OnPeersWentOffline(peersID)
|
||||
}
|
||||
|
||||
// extractIPLiteral returns the IP from address forms produced by the relay
|
||||
// dialers (URL or host:port). Zero value if the host is not an IP.
|
||||
func extractIPLiteral(s string) netip.Addr {
|
||||
if u, err := url.Parse(s); err == nil && u.Host != "" {
|
||||
s = u.Host
|
||||
}
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
host = s
|
||||
}
|
||||
host = strings.Trim(host, "[]")
|
||||
ip, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
280
shared/relay/client/client_serverip_test.go
Normal file
280
shared/relay/client/client_serverip_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/shared/relay/auth/allow"
|
||||
)
|
||||
|
||||
// TestClient_ServerIPRecoversFromUnresolvableFQDN verifies that when the
|
||||
// primary FQDN-based dial fails (unresolvable .invalid host), Connect
|
||||
// recovers via the server IP and SNI still uses the FQDN.
|
||||
func TestClient_ServerIPRecoversFromUnresolvableFQDN(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
listenAddr, port := freeAddr(t)
|
||||
srvCfg := server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port),
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
}
|
||||
srv, err := server.NewServer(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("create server: %s", err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("shutdown server: %s", err)
|
||||
}
|
||||
})
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("server failed to start: %s", err)
|
||||
}
|
||||
|
||||
t.Run("no server IP, primary fails", func(t *testing.T) {
|
||||
c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-noip", iface.DefaultMTU)
|
||||
err := c.Connect(ctx)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
t.Fatalf("expected connect to fail without server IP, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server IP recovers", func(t *testing.T) {
|
||||
c := NewClientWithServerIP(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-with-ip", iface.DefaultMTU)
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
t.Fatalf("connect with server IP: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = c.Close() })
|
||||
|
||||
if !c.Ready() {
|
||||
t.Fatalf("client not ready after connect")
|
||||
}
|
||||
if got := c.ConnectedIP(); got.String() != "127.0.0.1" {
|
||||
t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the
|
||||
// resolved IP after a successful FQDN-based dial. The underlying socket's
|
||||
// RemoteAddr must be exposed through the dialer wrappers; if it returns
|
||||
// the dial-time URL instead, ConnectedIP returns empty and the dial
|
||||
// IP we advertise to peers is empty too.
|
||||
func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
listenAddr, port := freeAddr(t)
|
||||
srvCfg := server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: fmt.Sprintf("rel://localhost:%d", port),
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
}
|
||||
srv, err := server.NewServer(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() { _ = srv.Shutdown(context.Background()) })
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("server failed to start: %s", err)
|
||||
}
|
||||
|
||||
c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-fqdn", iface.DefaultMTU)
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
t.Fatalf("connect: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = c.Close() })
|
||||
|
||||
got := c.ConnectedIP().String()
|
||||
if got != "127.0.0.1" && got != "::1" {
|
||||
t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubstituteHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverURL string
|
||||
ip string
|
||||
wantURL string
|
||||
wantServerName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "rels with port",
|
||||
serverURL: "rels://relay.netbird.io:443",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5:443",
|
||||
wantServerName: "relay.netbird.io",
|
||||
},
|
||||
{
|
||||
name: "rel with port",
|
||||
serverURL: "rel://relay.example.com:80",
|
||||
ip: "192.0.2.1",
|
||||
wantURL: "rel://192.0.2.1:80",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "ipv6 server IP bracketed",
|
||||
serverURL: "rels://relay.example.com:443",
|
||||
ip: "2001:db8::1",
|
||||
wantURL: "rels://[2001:db8::1]:443",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "no port",
|
||||
serverURL: "rels://relay.example.com",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "ipv6 server with port returns empty SNI",
|
||||
serverURL: "rels://[2001:db8::5]:443",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5:443",
|
||||
wantServerName: "",
|
||||
},
|
||||
{
|
||||
name: "ipv4 server with port returns empty SNI",
|
||||
serverURL: "rels://10.0.0.5:443",
|
||||
ip: "10.0.0.6",
|
||||
wantURL: "rels://10.0.0.6:443",
|
||||
wantServerName: "",
|
||||
},
|
||||
{
|
||||
name: "ipv6 server IP no port",
|
||||
serverURL: "rels://relay.example.com",
|
||||
ip: "2001:db8::1",
|
||||
wantURL: "rels://[2001:db8::1]",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "missing scheme",
|
||||
serverURL: "relay.example.com:443",
|
||||
ip: "10.0.0.5",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
serverURL: "",
|
||||
ip: "10.0.0.5",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ip netip.Addr
|
||||
if tt.ip != "" {
|
||||
ip = netip.MustParseAddr(tt.ip)
|
||||
}
|
||||
gotURL, gotName, err := substituteHost(tt.serverURL, ip)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if gotURL != tt.wantURL {
|
||||
t.Errorf("URL = %q, want %q", gotURL, tt.wantURL)
|
||||
}
|
||||
if gotName != tt.wantServerName {
|
||||
t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) {
|
||||
c := NewClient("rel://example.invalid:80", hmacTokenStore, "x", iface.DefaultMTU)
|
||||
if got := c.ConnectedIP(); got.IsValid() {
|
||||
t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
// staticAddr is a net.Addr that returns a fixed string. Used to verify
|
||||
// ConnectedIP parses RemoteAddr correctly.
|
||||
type staticAddr struct{ s string }
|
||||
|
||||
func (a staticAddr) Network() string { return "tcp" }
|
||||
func (a staticAddr) String() string { return a.s }
|
||||
|
||||
type stubConn struct {
|
||||
net.Conn
|
||||
remote net.Addr
|
||||
}
|
||||
|
||||
func (s stubConn) RemoteAddr() net.Addr { return s.remote }
|
||||
|
||||
func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"},
|
||||
{"hostport ipv6 bracketed", "[::1]:50301", "::1"},
|
||||
{"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"},
|
||||
{"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"fqdn url returns empty", "rel://relay.example.com:50301", ""},
|
||||
{"fqdn hostport returns empty", "relay.example.com:50301", ""},
|
||||
{"plain ipv4 no port", "10.0.0.1", "10.0.0.1"},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}}
|
||||
got := c.ConnectedIP()
|
||||
var gotStr string
|
||||
if got.IsValid() {
|
||||
gotStr = got.String()
|
||||
}
|
||||
if gotStr != tt.want {
|
||||
t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The
|
||||
// listener is closed before returning, so the port is briefly free for
|
||||
// the caller to bind. Avoids hardcoded ports that can collide.
|
||||
func freeAddr(t *testing.T) (string, int) {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("get free port: %s", err)
|
||||
}
|
||||
addr := l.Addr().(*net.TCPAddr)
|
||||
_ = l.Close()
|
||||
return addr.String(), addr.Port
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
// Get the base TLS config
|
||||
tlsClientConfig := quictls.ClientQUICTLSConfig()
|
||||
|
||||
// Set ServerName to hostname if not an IP address
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
// It's a hostname, not an IP - modify directly
|
||||
tlsClientConfig.ServerName = host
|
||||
switch {
|
||||
case serverName != "" && net.ParseIP(serverName) == nil:
|
||||
tlsClientConfig.ServerName = serverName
|
||||
default:
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
tlsClientConfig.ServerName = host
|
||||
}
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
|
||||
@@ -14,7 +14,9 @@ const (
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
Dial(ctx context.Context, address string) (net.Conn, error)
|
||||
// Dial connects to address. serverName, when non-empty, overrides the TLS
|
||||
// ServerName used for SNI/cert validation. Empty means derive from address.
|
||||
Dial(ctx context.Context, address, serverName string) (net.Conn, error)
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
@@ -27,6 +29,7 @@ type dialResult struct {
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
}
|
||||
@@ -40,6 +43,16 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
|
||||
}
|
||||
}
|
||||
|
||||
// WithServerName sets a TLS SNI/cert validation override. Used when serverURL
|
||||
// contains an IP literal but the cert is issued for a different hostname.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
r.serverName = serverName
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
@@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(ctx, r.serverURL)
|
||||
conn, err := dfn.Dial(ctx, r.serverURL, r.serverName)
|
||||
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user