mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-05 16:46:39 +00:00
Compare commits
6 Commits
feat/resel
...
fix/debug-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6cb25de9ea | ||
|
|
97db824929 | ||
|
|
77a0992dc2 | ||
|
|
104990dfdd | ||
|
|
bde632c3b2 | ||
|
|
4268a5cfb7 |
@@ -58,6 +58,11 @@ linters:
|
||||
govet:
|
||||
enable:
|
||||
- nilness
|
||||
disable:
|
||||
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
|
||||
# directives but cannot perform the rewrite due to generic type
|
||||
# parameter inference limitations in the Go inliner.
|
||||
- inline
|
||||
enable-all: false
|
||||
revive:
|
||||
rules:
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
}
|
||||
@@ -256,7 +258,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
}
|
||||
|
||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
|
||||
|
||||
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -324,7 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
if err != nil {
|
||||
@@ -334,7 +336,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
@@ -348,6 +350,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
if showQR {
|
||||
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
|
||||
printQRCode(f, verificationURIComplete)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
|
||||
25
client/cmd/qr.go
Normal file
25
client/cmd/qr.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
)
|
||||
|
||||
// printQRCode prints a QR code for the given URL to the writer.
|
||||
// Called only when the user explicitly requests QR output via --qr.
|
||||
func printQRCode(w io.Writer, url string) {
|
||||
if url == "" {
|
||||
return
|
||||
}
|
||||
qrterminal.GenerateWithConfig(url, qrterminal.Config{
|
||||
Level: qrterminal.M,
|
||||
Writer: w,
|
||||
HalfBlocks: true,
|
||||
BlackChar: qrterminal.BLACK_BLACK,
|
||||
WhiteChar: qrterminal.WHITE_WHITE,
|
||||
BlackWhiteChar: qrterminal.BLACK_WHITE,
|
||||
WhiteBlackChar: qrterminal.WHITE_BLACK,
|
||||
QuietZone: qrterminal.QUIET_ZONE,
|
||||
})
|
||||
}
|
||||
26
client/cmd/qr_test.go
Normal file
26
client/cmd/qr_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrintQRCode_EmptyURL(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "")
|
||||
|
||||
if buf.Len() != 0 {
|
||||
t.Error("expected no output for empty URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintQRCode_WritesOutput(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
printQRCode(&buf, "https://example.com/auth")
|
||||
|
||||
if buf.Len() == 0 {
|
||||
t.Error("expected QR code output for non-empty URL")
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,9 @@ const (
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
|
||||
showQRFlag = "qr"
|
||||
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
|
||||
|
||||
profileNameFlag = "profile"
|
||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||
)
|
||||
@@ -48,6 +51,7 @@ var (
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
showQR bool
|
||||
profileName string
|
||||
configPath string
|
||||
|
||||
@@ -80,6 +84,7 @@ func init() {
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||
|
||||
|
||||
@@ -607,6 +607,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||
}
|
||||
if g.internalConfig.DisableSSHAuth != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
|
||||
}
|
||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
@@ -633,6 +639,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addProf() (err error) {
|
||||
|
||||
@@ -5,16 +5,21 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
@@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
assert.Contains(t, anonNftables, "chain input {")
|
||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||
}
|
||||
|
||||
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
|
||||
// profilemanager.Config is either rendered in the debug bundle or explicitly
|
||||
// excluded. When a new field is added to Config, this test fails until the
|
||||
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
|
||||
// the excluded set with a justification.
|
||||
func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
excluded := map[string]string{
|
||||
"PrivateKey": "sensitive: WireGuard private key",
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
aURL, _ := url.Parse("https://admin.example.com:443")
|
||||
bTrue := true
|
||||
iVal := 42
|
||||
cfg := &profilemanager.Config{
|
||||
PrivateKey: "priv",
|
||||
PreSharedKey: "psk",
|
||||
ManagementURL: mURL,
|
||||
AdminURL: aURL,
|
||||
WgIface: "wt0",
|
||||
WgPort: 51820,
|
||||
NetworkMonitor: &bTrue,
|
||||
IFaceBlackList: []string{"eth0"},
|
||||
DisableIPv6Discovery: true,
|
||||
RosenpassEnabled: true,
|
||||
RosenpassPermissive: true,
|
||||
ServerSSHAllowed: &bTrue,
|
||||
EnableSSHRoot: &bTrue,
|
||||
EnableSSHSFTP: &bTrue,
|
||||
EnableSSHLocalPortForwarding: &bTrue,
|
||||
EnableSSHRemotePortForwarding: &bTrue,
|
||||
DisableSSHAuth: &bTrue,
|
||||
SSHJWTCacheTTL: &iVal,
|
||||
DisableClientRoutes: true,
|
||||
DisableServerRoutes: true,
|
||||
DisableDNS: true,
|
||||
DisableFirewall: true,
|
||||
BlockLANAccess: true,
|
||||
BlockInbound: true,
|
||||
DisableNotifications: &bTrue,
|
||||
DNSLabels: domain.List{},
|
||||
SSHKey: "sshkey",
|
||||
NATExternalIPs: []string{"1.2.3.4"},
|
||||
CustomDNSAddress: "1.1.1.1:53",
|
||||
DisableAutoConnect: true,
|
||||
DNSRouteInterval: 5 * time.Second,
|
||||
ClientCertPath: "/tmp/cert",
|
||||
ClientCertKeyPath: "/tmp/key",
|
||||
LazyConnectionEnabled: true,
|
||||
MTU: 1280,
|
||||
}
|
||||
|
||||
for _, anonymize := range []bool{false, true} {
|
||||
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
|
||||
g := &BundleGenerator{
|
||||
anonymizer: newAnonymizerForTest(),
|
||||
internalConfig: cfg,
|
||||
anonymize: anonymize,
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
g.addCommonConfigFields(&sb)
|
||||
rendered := sb.String() + renderAddConfigSpecific(g)
|
||||
|
||||
val := reflect.ValueOf(cfg).Elem()
|
||||
typ := val.Type()
|
||||
var missing []string
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
name := typ.Field(i).Name
|
||||
if _, ok := excluded[name]; ok {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(rendered, name+":") {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
|
||||
"Either render the field in addCommonConfigFields/addConfig, "+
|
||||
"or add it to the excluded map with a justification.", missing)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
|
||||
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
|
||||
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
|
||||
// production shape without needing to write an actual zip.
|
||||
func renderAddConfigSpecific(g *BundleGenerator) string {
|
||||
var sb strings.Builder
|
||||
if g.anonymize {
|
||||
if g.internalConfig.ManagementURL != nil {
|
||||
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
|
||||
}
|
||||
if g.internalConfig.AdminURL != nil {
|
||||
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
|
||||
}
|
||||
sb.WriteString("NATExternalIPs: x\n")
|
||||
if g.internalConfig.CustomDNSAddress != "" {
|
||||
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
|
||||
}
|
||||
} else {
|
||||
if g.internalConfig.ManagementURL != nil {
|
||||
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
|
||||
}
|
||||
if g.internalConfig.AdminURL != nil {
|
||||
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
|
||||
}
|
||||
sb.WriteString("NATExternalIPs: x\n")
|
||||
if g.internalConfig.CustomDNSAddress != "" {
|
||||
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func newAnonymizerForTest() *anonymize.Anonymizer {
|
||||
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||
|
||||
// Start begins monitoring the WireGuard interface.
|
||||
// It relies on the provided context cancellation to stop.
|
||||
//
|
||||
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
|
||||
// to avoid the allocation churn of repeatedly dumping the kernel link
|
||||
// table; on other platforms it falls back to a low-frequency poll.
|
||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||
defer close(m.done)
|
||||
|
||||
@@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
||||
|
||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return watchInterface(ctx, ifaceName, expectedIndex)
|
||||
}
|
||||
|
||||
// getInterfaceIndex returns the index of a network interface by name.
|
||||
|
||||
134
client/internal/wg_iface_monitor_linux.go
Normal file
134
client/internal/wg_iface_monitor_linux.go
Normal file
@@ -0,0 +1,134 @@
|
||||
//go:build linux
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
|
||||
// deletion or recreation of the WireGuard interface.
|
||||
//
|
||||
// The previous implementation polled net.InterfaceByName every 2 s, which
|
||||
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
|
||||
// entire kernel link table on every call. On hosts with many veth
|
||||
// interfaces (containers, bridges) the resulting allocation churn was on
|
||||
// the order of ~1 GB/day from this single ticker, which on small ARM
|
||||
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
|
||||
//
|
||||
// The event-driven version below allocates only when the kernel actually
|
||||
// publishes a link event for the tracked interface — typically zero
|
||||
// allocations between events.
|
||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
// Buffer the channel to absorb event bursts (e.g. when many veth
|
||||
// pairs are created/destroyed at once by container runtimes).
|
||||
linkChan := make(chan netlink.LinkUpdate, 32)
|
||||
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||
// Return shouldRestart=true so the engine recovers monitoring
|
||||
// via triggerClientRestart instead of silently losing it for
|
||||
// the rest of the process lifetime.
|
||||
return true, fmt.Errorf("subscribe to link updates: %w", err)
|
||||
}
|
||||
|
||||
// Race window: the interface could have been deleted (or recreated)
|
||||
// between the initial getInterfaceIndex() in Start and LinkSubscribe
|
||||
// completing its handshake with the kernel. Re-check explicitly so we
|
||||
// do not block forever waiting for an event that already fired.
|
||||
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
|
||||
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
} else if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||
|
||||
case update, ok := <-linkChan:
|
||||
if !ok {
|
||||
// The vishvananda/netlink subscription goroutine closes
|
||||
// the channel on receive errors. Signal the engine to
|
||||
// restart so monitoring is re-established instead of
|
||||
// silently ending.
|
||||
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
|
||||
return true, fmt.Errorf("link subscription channel closed unexpectedly")
|
||||
}
|
||||
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inspectLinkEvent classifies a single netlink link update against the
|
||||
// tracked WireGuard interface. It returns (true, err) when the engine
|
||||
// should restart monitoring; (false, nil) means the event is unrelated
|
||||
// and the caller should keep waiting.
|
||||
//
|
||||
// The error component, when non-nil, describes the kernel-side reason
|
||||
// (deletion or rename); the recreation case returns (true, nil) since
|
||||
// no error condition is reported.
|
||||
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
|
||||
eventIndex := int(update.Index)
|
||||
eventName := ""
|
||||
if attrs := update.Attrs(); attrs != nil {
|
||||
eventName = attrs.Name
|
||||
}
|
||||
|
||||
switch update.Header.Type {
|
||||
case syscall.RTM_DELLINK:
|
||||
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
|
||||
case syscall.RTM_NEWLINK:
|
||||
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
|
||||
// tracked interface index.
|
||||
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
|
||||
if eventIndex != expectedIndex {
|
||||
return false, nil
|
||||
}
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted", ifaceName)
|
||||
}
|
||||
|
||||
// inspectNewLink reports a restart when an RTM_NEWLINK either:
|
||||
//
|
||||
// 1. Introduces a link with our name at a different index (recreation
|
||||
// after a delete), or
|
||||
//
|
||||
// 2. Reports a link still at our index but with a different name
|
||||
// (in-place rename). The previous polling implementation caught
|
||||
// this implicitly because net.InterfaceByName(ifaceName) would
|
||||
// start failing; the event-driven version has to test it.
|
||||
//
|
||||
// Same name + same index is just a flag/state change on the existing
|
||||
// interface and is ignored.
|
||||
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
|
||||
if eventName == ifaceName && eventIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, eventIndex)
|
||||
return true, nil
|
||||
}
|
||||
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
|
||||
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
|
||||
ifaceName, eventName, expectedIndex)
|
||||
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
56
client/internal/wg_iface_monitor_other.go
Normal file
56
client/internal/wg_iface_monitor_other.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build !linux
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// watchInterface polls net.InterfaceByName at a fixed interval to detect
|
||||
// deletion or recreation of the WireGuard interface.
|
||||
//
|
||||
// This is the fallback used on non-Linux desktop and server platforms
|
||||
// (darwin, windows, freebsd). It is also compiled on android and ios so
|
||||
// the package builds on every supported GOOS, but it is never reached
|
||||
// at runtime there because Start() in wg_iface_monitor.go exits early
|
||||
// on mobile platforms.
|
||||
//
|
||||
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
|
||||
// RTNLGRP_LINK netlink subscription instead, because on Linux
|
||||
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
|
||||
// dumps the entire kernel link table on every call and produces
|
||||
// significant allocation churn (netbirdio/netbird#3678).
|
||||
//
|
||||
// Windows is also reported in #3678 as affected by RSS climb. A future
|
||||
// follow-up could implement an event-driven watcher there using
|
||||
// NotifyIpInterfaceChange from iphlpapi.
|
||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -224,15 +224,20 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
sshConfigPathTmp := sshConfigPath + ".tmp"
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
|
||||
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -71,6 +71,7 @@ require (
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||
@@ -309,6 +310,7 @@ require (
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
rsc.io/qr v0.2.0 // indirect
|
||||
)
|
||||
|
||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
||||
|
||||
4
go.sum
4
go.sum
@@ -415,6 +415,8 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||
@@ -915,3 +917,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||
Heartbeat(ctx context.Context, p *Proxy) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
@@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
}
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
SessionID: sessionID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
@@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"sessionID": sessionID,
|
||||
"clusterAddress": clusterAddress,
|
||||
"ipAddress": ipAddress,
|
||||
}).Info("proxy connected")
|
||||
|
||||
return nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
// Disconnect marks a proxy as disconnected in the database.
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"proxyID": proxyID,
|
||||
"sessionID": sessionID,
|
||||
}).Info("proxy disconnected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
// Heartbeat updates the proxy's last seen timestamp.
|
||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
|
||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
|
||||
m.metrics.IncrementProxyHeartbeatCount()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Disconnect indicates an expected call of Disconnect.
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses mocks base method.
|
||||
@@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Heartbeat indicates an expected call of Heartbeat.
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
|
||||
@@ -18,12 +18,13 @@ type Capabilities struct {
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
SessionID string `gorm:"type:varchar(36)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Capabilities Capabilities `gorm:"embedded"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -89,6 +90,7 @@ const pkceVerifierTTL = 10 * time.Minute
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
sessionID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
@@ -166,9 +168,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
@@ -188,12 +203,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
s.connectedProxies.CompareAndDelete(proxyID, conn)
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
@@ -202,22 +218,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
@@ -227,29 +248,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
// Start heartbeat goroutine
|
||||
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
|
||||
go s.heartbeat(connCtx, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5437,13 +5437,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
// DisconnectProxy marks a proxy as disconnected only if the session ID matches.
|
||||
// This prevents a slow-to-close old session from overwriting a newer reconnection.
|
||||
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
now := time.Now()
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND session_id = ?", proxyID, sessionID).
|
||||
Updates(map[string]any{
|
||||
"status": "disconnected",
|
||||
"disconnected_at": now,
|
||||
"last_seen": now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error)
|
||||
return status.Errorf(status.Internal, "failed to disconnect proxy")
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
now := time.Now()
|
||||
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
|
||||
Update("last_seen", now)
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -5452,17 +5474,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
if err := s.db.Save(p).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
|
||||
p.LastSeen = now
|
||||
p.ConnectedAt = &now
|
||||
p.Status = "connected"
|
||||
if err := s.db.Create(p).Error; err != nil {
|
||||
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -284,7 +284,8 @@ type Store interface {
|
||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||
|
||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2799,6 +2800,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2995,17 +3010,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
|
||||
@@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||
type testProxyManager struct{}
|
||||
|
||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
|
||||
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -656,3 +656,72 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
|
||||
// when a proxy reconnects before the old stream's cleanup runs, the new
|
||||
// connection is NOT removed by the stale defer.
|
||||
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-race"
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := stream1.Recv()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
"proxy should be registered after first connection")
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := stream2.Recv()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
cancel1()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
"proxy should still be registered after old connection cleanup — old defer must not remove new connection")
|
||||
|
||||
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||
Id: "rp-1",
|
||||
AccountId: "test-account-1",
|
||||
Domain: "app1.test.proxy.io",
|
||||
}},
|
||||
})
|
||||
|
||||
msg, err := stream2.Recv()
|
||||
require.NoError(t, err, "new stream should still receive updates")
|
||||
require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
|
||||
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user