mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 00:56:39 +00:00
Compare commits
5 Commits
ssh-config
...
fix/debug-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6cb25de9ea | ||
|
|
97db824929 | ||
|
|
77a0992dc2 | ||
|
|
104990dfdd | ||
|
|
bde632c3b2 |
@@ -58,6 +58,11 @@ linters:
|
|||||||
govet:
|
govet:
|
||||||
enable:
|
enable:
|
||||||
- nilness
|
- nilness
|
||||||
|
disable:
|
||||||
|
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
|
||||||
|
# directives but cannot perform the rewrite due to generic type
|
||||||
|
# parameter inference limitations in the Go inliner.
|
||||||
|
- inline
|
||||||
enable-all: false
|
enable-all: false
|
||||||
revive:
|
revive:
|
||||||
rules:
|
rules:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
@@ -23,6 +24,7 @@ import (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||||
}
|
}
|
||||||
@@ -256,7 +258,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
|
||||||
|
|
||||||
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -324,7 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -334,7 +336,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
@@ -348,6 +350,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
verificationURIComplete + " " + codeMsg)
|
verificationURIComplete + " " + codeMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if showQR {
|
||||||
|
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
|
||||||
|
printQRCode(f, verificationURIComplete)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
|
|||||||
25
client/cmd/qr.go
Normal file
25
client/cmd/qr.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/mdp/qrterminal/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// printQRCode prints a QR code for the given URL to the writer.
|
||||||
|
// Called only when the user explicitly requests QR output via --qr.
|
||||||
|
func printQRCode(w io.Writer, url string) {
|
||||||
|
if url == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
qrterminal.GenerateWithConfig(url, qrterminal.Config{
|
||||||
|
Level: qrterminal.M,
|
||||||
|
Writer: w,
|
||||||
|
HalfBlocks: true,
|
||||||
|
BlackChar: qrterminal.BLACK_BLACK,
|
||||||
|
WhiteChar: qrterminal.WHITE_WHITE,
|
||||||
|
BlackWhiteChar: qrterminal.BLACK_WHITE,
|
||||||
|
WhiteBlackChar: qrterminal.WHITE_BLACK,
|
||||||
|
QuietZone: qrterminal.QUIET_ZONE,
|
||||||
|
})
|
||||||
|
}
|
||||||
26
client/cmd/qr_test.go
Normal file
26
client/cmd/qr_test.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrintQRCode_EmptyURL(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
printQRCode(&buf, "")
|
||||||
|
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
t.Error("expected no output for empty URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrintQRCode_WritesOutput(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
printQRCode(&buf, "https://example.com/auth")
|
||||||
|
|
||||||
|
if buf.Len() == 0 {
|
||||||
|
t.Error("expected QR code output for non-empty URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -39,6 +39,9 @@ const (
|
|||||||
noBrowserFlag = "no-browser"
|
noBrowserFlag = "no-browser"
|
||||||
noBrowserDesc = "do not open the browser for SSO login"
|
noBrowserDesc = "do not open the browser for SSO login"
|
||||||
|
|
||||||
|
showQRFlag = "qr"
|
||||||
|
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
|
||||||
|
|
||||||
profileNameFlag = "profile"
|
profileNameFlag = "profile"
|
||||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||||
)
|
)
|
||||||
@@ -48,6 +51,7 @@ var (
|
|||||||
dnsLabels []string
|
dnsLabels []string
|
||||||
dnsLabelsValidated domain.List
|
dnsLabelsValidated domain.List
|
||||||
noBrowser bool
|
noBrowser bool
|
||||||
|
showQR bool
|
||||||
profileName string
|
profileName string
|
||||||
configPath string
|
configPath string
|
||||||
|
|
||||||
@@ -80,6 +84,7 @@ func init() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||||
|
|
||||||
|
|||||||
@@ -607,6 +607,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||||
}
|
}
|
||||||
|
if g.internalConfig.DisableSSHAuth != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
|
||||||
|
}
|
||||||
|
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||||
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
@@ -633,6 +639,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
}
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||||
|
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addProf() (err error) {
|
func (g *BundleGenerator) addProf() (err error) {
|
||||||
|
|||||||
@@ -5,16 +5,21 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/configs"
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
|||||||
assert.Contains(t, anonNftables, "chain input {")
|
assert.Contains(t, anonNftables, "chain input {")
|
||||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
|
||||||
|
// profilemanager.Config is either rendered in the debug bundle or explicitly
|
||||||
|
// excluded. When a new field is added to Config, this test fails until the
|
||||||
|
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
|
||||||
|
// the excluded set with a justification.
|
||||||
|
func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||||
|
excluded := map[string]string{
|
||||||
|
"PrivateKey": "sensitive: WireGuard private key",
|
||||||
|
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||||
|
"SSHKey": "sensitive: SSH private key",
|
||||||
|
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||||
|
}
|
||||||
|
|
||||||
|
mURL, _ := url.Parse("https://api.example.com:443")
|
||||||
|
aURL, _ := url.Parse("https://admin.example.com:443")
|
||||||
|
bTrue := true
|
||||||
|
iVal := 42
|
||||||
|
cfg := &profilemanager.Config{
|
||||||
|
PrivateKey: "priv",
|
||||||
|
PreSharedKey: "psk",
|
||||||
|
ManagementURL: mURL,
|
||||||
|
AdminURL: aURL,
|
||||||
|
WgIface: "wt0",
|
||||||
|
WgPort: 51820,
|
||||||
|
NetworkMonitor: &bTrue,
|
||||||
|
IFaceBlackList: []string{"eth0"},
|
||||||
|
DisableIPv6Discovery: true,
|
||||||
|
RosenpassEnabled: true,
|
||||||
|
RosenpassPermissive: true,
|
||||||
|
ServerSSHAllowed: &bTrue,
|
||||||
|
EnableSSHRoot: &bTrue,
|
||||||
|
EnableSSHSFTP: &bTrue,
|
||||||
|
EnableSSHLocalPortForwarding: &bTrue,
|
||||||
|
EnableSSHRemotePortForwarding: &bTrue,
|
||||||
|
DisableSSHAuth: &bTrue,
|
||||||
|
SSHJWTCacheTTL: &iVal,
|
||||||
|
DisableClientRoutes: true,
|
||||||
|
DisableServerRoutes: true,
|
||||||
|
DisableDNS: true,
|
||||||
|
DisableFirewall: true,
|
||||||
|
BlockLANAccess: true,
|
||||||
|
BlockInbound: true,
|
||||||
|
DisableNotifications: &bTrue,
|
||||||
|
DNSLabels: domain.List{},
|
||||||
|
SSHKey: "sshkey",
|
||||||
|
NATExternalIPs: []string{"1.2.3.4"},
|
||||||
|
CustomDNSAddress: "1.1.1.1:53",
|
||||||
|
DisableAutoConnect: true,
|
||||||
|
DNSRouteInterval: 5 * time.Second,
|
||||||
|
ClientCertPath: "/tmp/cert",
|
||||||
|
ClientCertKeyPath: "/tmp/key",
|
||||||
|
LazyConnectionEnabled: true,
|
||||||
|
MTU: 1280,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, anonymize := range []bool{false, true} {
|
||||||
|
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
|
||||||
|
g := &BundleGenerator{
|
||||||
|
anonymizer: newAnonymizerForTest(),
|
||||||
|
internalConfig: cfg,
|
||||||
|
anonymize: anonymize,
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
g.addCommonConfigFields(&sb)
|
||||||
|
rendered := sb.String() + renderAddConfigSpecific(g)
|
||||||
|
|
||||||
|
val := reflect.ValueOf(cfg).Elem()
|
||||||
|
typ := val.Type()
|
||||||
|
var missing []string
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
name := typ.Field(i).Name
|
||||||
|
if _, ok := excluded[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.Contains(rendered, name+":") {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
|
||||||
|
"Either render the field in addCommonConfigFields/addConfig, "+
|
||||||
|
"or add it to the excluded map with a justification.", missing)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
|
||||||
|
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
|
||||||
|
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
|
||||||
|
// production shape without needing to write an actual zip.
|
||||||
|
func renderAddConfigSpecific(g *BundleGenerator) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
if g.anonymize {
|
||||||
|
if g.internalConfig.ManagementURL != nil {
|
||||||
|
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
|
||||||
|
}
|
||||||
|
if g.internalConfig.AdminURL != nil {
|
||||||
|
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("NATExternalIPs: x\n")
|
||||||
|
if g.internalConfig.CustomDNSAddress != "" {
|
||||||
|
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if g.internalConfig.ManagementURL != nil {
|
||||||
|
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
|
||||||
|
}
|
||||||
|
if g.internalConfig.AdminURL != nil {
|
||||||
|
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("NATExternalIPs: x\n")
|
||||||
|
if g.internalConfig.CustomDNSAddress != "" {
|
||||||
|
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAnonymizerForTest() *anonymize.Anonymizer {
|
||||||
|
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
|
|||||||
|
|
||||||
// Start begins monitoring the WireGuard interface.
|
// Start begins monitoring the WireGuard interface.
|
||||||
// It relies on the provided context cancellation to stop.
|
// It relies on the provided context cancellation to stop.
|
||||||
|
//
|
||||||
|
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
|
||||||
|
// to avoid the allocation churn of repeatedly dumping the kernel link
|
||||||
|
// table; on other platforms it falls back to a low-frequency poll.
|
||||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||||
defer close(m.done)
|
defer close(m.done)
|
||||||
|
|
||||||
@@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
|||||||
|
|
||||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||||
|
|
||||||
ticker := time.NewTicker(2 * time.Second)
|
return watchInterface(ctx, ifaceName, expectedIndex)
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
|
||||||
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
|
||||||
case <-ticker.C:
|
|
||||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
|
||||||
if err != nil {
|
|
||||||
// Interface was deleted
|
|
||||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
|
||||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if interface index changed (interface was recreated)
|
|
||||||
if currentIndex != expectedIndex {
|
|
||||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
|
||||||
ifaceName, expectedIndex, currentIndex)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getInterfaceIndex returns the index of a network interface by name.
|
// getInterfaceIndex returns the index of a network interface by name.
|
||||||
|
|||||||
134
client/internal/wg_iface_monitor_linux.go
Normal file
134
client/internal/wg_iface_monitor_linux.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
|
||||||
|
// deletion or recreation of the WireGuard interface.
|
||||||
|
//
|
||||||
|
// The previous implementation polled net.InterfaceByName every 2 s, which
|
||||||
|
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
|
||||||
|
// entire kernel link table on every call. On hosts with many veth
|
||||||
|
// interfaces (containers, bridges) the resulting allocation churn was on
|
||||||
|
// the order of ~1 GB/day from this single ticker, which on small ARM
|
||||||
|
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
|
||||||
|
//
|
||||||
|
// The event-driven version below allocates only when the kernel actually
|
||||||
|
// publishes a link event for the tracked interface — typically zero
|
||||||
|
// allocations between events.
|
||||||
|
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
// Buffer the channel to absorb event bursts (e.g. when many veth
|
||||||
|
// pairs are created/destroyed at once by container runtimes).
|
||||||
|
linkChan := make(chan netlink.LinkUpdate, 32)
|
||||||
|
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||||
|
// Return shouldRestart=true so the engine recovers monitoring
|
||||||
|
// via triggerClientRestart instead of silently losing it for
|
||||||
|
// the rest of the process lifetime.
|
||||||
|
return true, fmt.Errorf("subscribe to link updates: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Race window: the interface could have been deleted (or recreated)
|
||||||
|
// between the initial getInterfaceIndex() in Start and LinkSubscribe
|
||||||
|
// completing its handshake with the kernel. Re-check explicitly so we
|
||||||
|
// do not block forever waiting for an event that already fired.
|
||||||
|
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
|
||||||
|
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
|
||||||
|
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||||
|
} else if currentIndex != expectedIndex {
|
||||||
|
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
|
||||||
|
ifaceName, expectedIndex, currentIndex)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||||
|
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||||
|
|
||||||
|
case update, ok := <-linkChan:
|
||||||
|
if !ok {
|
||||||
|
// The vishvananda/netlink subscription goroutine closes
|
||||||
|
// the channel on receive errors. Signal the engine to
|
||||||
|
// restart so monitoring is re-established instead of
|
||||||
|
// silently ending.
|
||||||
|
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
|
||||||
|
return true, fmt.Errorf("link subscription channel closed unexpectedly")
|
||||||
|
}
|
||||||
|
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inspectLinkEvent classifies a single netlink link update against the
|
||||||
|
// tracked WireGuard interface. It returns (true, err) when the engine
|
||||||
|
// should restart monitoring; (false, nil) means the event is unrelated
|
||||||
|
// and the caller should keep waiting.
|
||||||
|
//
|
||||||
|
// The error component, when non-nil, describes the kernel-side reason
|
||||||
|
// (deletion or rename); the recreation case returns (true, nil) since
|
||||||
|
// no error condition is reported.
|
||||||
|
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
|
||||||
|
eventIndex := int(update.Index)
|
||||||
|
eventName := ""
|
||||||
|
if attrs := update.Attrs(); attrs != nil {
|
||||||
|
eventName = attrs.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
switch update.Header.Type {
|
||||||
|
case syscall.RTM_DELLINK:
|
||||||
|
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
|
||||||
|
case syscall.RTM_NEWLINK:
|
||||||
|
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
|
||||||
|
// tracked interface index.
|
||||||
|
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
|
||||||
|
if eventIndex != expectedIndex {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||||
|
return true, fmt.Errorf("interface %s deleted", ifaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// inspectNewLink reports a restart when an RTM_NEWLINK either:
|
||||||
|
//
|
||||||
|
// 1. Introduces a link with our name at a different index (recreation
|
||||||
|
// after a delete), or
|
||||||
|
//
|
||||||
|
// 2. Reports a link still at our index but with a different name
|
||||||
|
// (in-place rename). The previous polling implementation caught
|
||||||
|
// this implicitly because net.InterfaceByName(ifaceName) would
|
||||||
|
// start failing; the event-driven version has to test it.
|
||||||
|
//
|
||||||
|
// Same name + same index is just a flag/state change on the existing
|
||||||
|
// interface and is ignored.
|
||||||
|
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
|
||||||
|
if eventName == ifaceName && eventIndex != expectedIndex {
|
||||||
|
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||||
|
ifaceName, expectedIndex, eventIndex)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
|
||||||
|
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
|
||||||
|
ifaceName, eventName, expectedIndex)
|
||||||
|
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
56
client/internal/wg_iface_monitor_other.go
Normal file
56
client/internal/wg_iface_monitor_other.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// watchInterface polls net.InterfaceByName at a fixed interval to detect
|
||||||
|
// deletion or recreation of the WireGuard interface.
|
||||||
|
//
|
||||||
|
// This is the fallback used on non-Linux desktop and server platforms
|
||||||
|
// (darwin, windows, freebsd). It is also compiled on android and ios so
|
||||||
|
// the package builds on every supported GOOS, but it is never reached
|
||||||
|
// at runtime there because Start() in wg_iface_monitor.go exits early
|
||||||
|
// on mobile platforms.
|
||||||
|
//
|
||||||
|
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
|
||||||
|
// RTNLGRP_LINK netlink subscription instead, because on Linux
|
||||||
|
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
|
||||||
|
// dumps the entire kernel link table on every call and produces
|
||||||
|
// significant allocation churn (netbirdio/netbird#3678).
|
||||||
|
//
|
||||||
|
// Windows is also reported in #3678 as affected by RSS climb. A future
|
||||||
|
// follow-up could implement an event-driven watcher there using
|
||||||
|
// NotifyIpInterfaceChange from iphlpapi.
|
||||||
|
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||||
|
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
// Interface was deleted
|
||||||
|
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||||
|
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if interface index changed (interface was recreated)
|
||||||
|
if currentIndex != expectedIndex {
|
||||||
|
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||||
|
ifaceName, expectedIndex, currentIndex)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
2
go.mod
2
go.mod
@@ -71,6 +71,7 @@ require (
|
|||||||
github.com/libp2p/go-netroute v0.2.1
|
github.com/libp2p/go-netroute v0.2.1
|
||||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
|
github.com/mdp/qrterminal/v3 v3.2.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||||
@@ -309,6 +310,7 @@ require (
|
|||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
|
rsc.io/qr v0.2.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -415,6 +415,8 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
|
|||||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
||||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||||
|
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
|
||||||
|
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
||||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||||
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||||
@@ -915,3 +917,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
|||||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||||
|
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||||
|
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
|
|
||||||
// Manager defines the interface for proxy operations
|
// Manager defines the interface for proxy operations
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||||
Disconnect(ctx context.Context, proxyID string) error
|
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
Heartbeat(ctx context.Context, p *Proxy) error
|
||||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ import (
|
|||||||
// store defines the interface for proxy persistence operations
|
// store defines the interface for proxy persistence operations
|
||||||
type store interface {
|
type store interface {
|
||||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||||
|
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
@@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
|||||||
|
|
||||||
// Connect registers a new proxy connection in the database.
|
// Connect registers a new proxy connection in the database.
|
||||||
// capabilities may be nil for old proxies that do not report them.
|
// capabilities may be nil for old proxies that do not report them.
|
||||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
var caps proxy.Capabilities
|
var caps proxy.Capabilities
|
||||||
if capabilities != nil {
|
if capabilities != nil {
|
||||||
@@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
|||||||
}
|
}
|
||||||
p := &proxy.Proxy{
|
p := &proxy.Proxy{
|
||||||
ID: proxyID,
|
ID: proxyID,
|
||||||
|
SessionID: sessionID,
|
||||||
ClusterAddress: clusterAddress,
|
ClusterAddress: clusterAddress,
|
||||||
IPAddress: ipAddress,
|
IPAddress: ipAddress,
|
||||||
LastSeen: now,
|
LastSeen: now,
|
||||||
@@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
|||||||
|
|
||||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
"proxyID": proxyID,
|
"proxyID": proxyID,
|
||||||
|
"sessionID": sessionID,
|
||||||
"clusterAddress": clusterAddress,
|
"clusterAddress": clusterAddress,
|
||||||
"ipAddress": ipAddress,
|
"ipAddress": ipAddress,
|
||||||
}).Info("proxy connected")
|
}).Info("proxy connected")
|
||||||
|
|
||||||
return nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect marks a proxy as disconnected in the database
|
// Disconnect marks a proxy as disconnected in the database.
|
||||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||||
now := time.Now()
|
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||||
p := &proxy.Proxy{
|
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||||
ID: proxyID,
|
|
||||||
Status: "disconnected",
|
|
||||||
DisconnectedAt: &now,
|
|
||||||
LastSeen: now,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
"proxyID": proxyID,
|
"proxyID": proxyID,
|
||||||
|
"sessionID": sessionID,
|
||||||
}).Info("proxy disconnected")
|
}).Info("proxy disconnected")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat updates the proxy's last seen timestamp
|
// Heartbeat updates the proxy's last seen timestamp.
|
||||||
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
|
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
|
||||||
m.metrics.IncrementProxyHeartbeatCount()
|
m.metrics.IncrementProxyHeartbeatCount()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect mocks base method.
|
// Connect mocks base method.
|
||||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(*Proxy)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect indicates an expected call of Connect.
|
// Connect indicates an expected call of Connect.
|
||||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect mocks base method.
|
// Disconnect mocks base method.
|
||||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect indicates an expected call of Disconnect.
|
// Disconnect indicates an expected call of Disconnect.
|
||||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusterAddresses mocks base method.
|
// GetActiveClusterAddresses mocks base method.
|
||||||
@@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat mocks base method.
|
// Heartbeat mocks base method.
|
||||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat indicates an expected call of Heartbeat.
|
// Heartbeat indicates an expected call of Heartbeat.
|
||||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockController is a mock of Controller interface.
|
// MockController is a mock of Controller interface.
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Capabilities struct {
|
|||||||
// Proxy represents a reverse proxy instance
|
// Proxy represents a reverse proxy instance
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||||
|
SessionID string `gorm:"type:varchar(36)"`
|
||||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||||
IPAddress string `gorm:"type:varchar(45)"`
|
IPAddress string `gorm:"type:varchar(45)"`
|
||||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@@ -89,6 +90,7 @@ const pkceVerifierTTL = 10 * time.Minute
|
|||||||
// proxyConnection represents a connected proxy
|
// proxyConnection represents a connected proxy
|
||||||
type proxyConnection struct {
|
type proxyConnection struct {
|
||||||
proxyID string
|
proxyID string
|
||||||
|
sessionID string
|
||||||
address string
|
address string
|
||||||
capabilities *proto.ProxyCapabilities
|
capabilities *proto.ProxyCapabilities
|
||||||
stream proto.ProxyService_GetMappingUpdateServer
|
stream proto.ProxyService_GetMappingUpdateServer
|
||||||
@@ -166,9 +168,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionID := uuid.NewString()
|
||||||
|
|
||||||
|
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||||
|
oldConn := old.(*proxyConnection)
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"proxy_id": proxyID,
|
||||||
|
"old_session_id": oldConn.sessionID,
|
||||||
|
"new_session_id": sessionID,
|
||||||
|
}).Info("Superseding existing proxy connection")
|
||||||
|
oldConn.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithCancel(ctx)
|
connCtx, cancel := context.WithCancel(ctx)
|
||||||
conn := &proxyConnection{
|
conn := &proxyConnection{
|
||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
|
sessionID: sessionID,
|
||||||
address: proxyAddress,
|
address: proxyAddress,
|
||||||
capabilities: req.GetCapabilities(),
|
capabilities: req.GetCapabilities(),
|
||||||
stream: stream,
|
stream: stream,
|
||||||
@@ -191,9 +206,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||||
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||||
s.connectedProxies.Delete(proxyID)
|
s.connectedProxies.CompareAndDelete(proxyID, conn)
|
||||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||||
}
|
}
|
||||||
@@ -202,22 +218,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"proxy_id": proxyID,
|
"proxy_id": proxyID,
|
||||||
|
"session_id": sessionID,
|
||||||
"address": proxyAddress,
|
"address": proxyAddress,
|
||||||
"cluster_addr": proxyAddress,
|
"cluster_addr": proxyAddress,
|
||||||
"total_proxies": len(s.GetConnectedProxies()),
|
"total_proxies": len(s.GetConnectedProxies()),
|
||||||
}).Info("Proxy registered in cluster")
|
}).Info("Proxy registered in cluster")
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Delete(proxyID)
|
|
||||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||||
}
|
}
|
||||||
|
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||||
|
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
log.Infof("Proxy %s disconnected", proxyID)
|
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||||
@@ -227,29 +248,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
go s.sender(conn, errChan)
|
go s.sender(conn, errChan)
|
||||||
|
|
||||||
// Start heartbeat goroutine
|
go s.heartbeat(connCtx, proxyRecord)
|
||||||
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||||
case <-connCtx.Done():
|
case <-connCtx.Done():
|
||||||
|
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||||
return connCtx.Err()
|
return connCtx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
|
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5437,13 +5437,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
|
// DisconnectProxy marks a proxy as disconnected only if the session ID matches.
|
||||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
// This prevents a slow-to-close old session from overwriting a newer reconnection.
|
||||||
|
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||||
|
now := time.Now()
|
||||||
|
result := s.db.
|
||||||
|
Model(&proxy.Proxy{}).
|
||||||
|
Where("id = ? AND session_id = ?", proxyID, sessionID).
|
||||||
|
Updates(map[string]any{
|
||||||
|
"status": "disconnected",
|
||||||
|
"disconnected_at": now,
|
||||||
|
"last_seen": now,
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to disconnect proxy")
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
|
||||||
|
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
result := s.db.
|
result := s.db.
|
||||||
Model(&proxy.Proxy{}).
|
Model(&proxy.Proxy{}).
|
||||||
Where("id = ? AND status = ?", proxyID, "connected").
|
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
|
||||||
Update("last_seen", now)
|
Update("last_seen", now)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -5452,17 +5474,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
p := &proxy.Proxy{
|
p.LastSeen = now
|
||||||
ID: proxyID,
|
p.ConnectedAt = &now
|
||||||
ClusterAddress: clusterAddress,
|
p.Status = "connected"
|
||||||
IPAddress: ipAddress,
|
if err := s.db.Create(p).Error; err != nil {
|
||||||
LastSeen: now,
|
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
|
||||||
ConnectedAt: &now,
|
|
||||||
Status: "connected",
|
|
||||||
}
|
|
||||||
if err := s.db.Save(p).Error; err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
|
|
||||||
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -284,7 +284,8 @@ type Store interface {
|
|||||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||||
|
|
||||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||||
|
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close mocks base method.
|
// Close mocks base method.
|
||||||
func (m *MockStore) Close(ctx context.Context) error {
|
func (m *MockStore) Close(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -2799,6 +2800,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DisconnectProxy mocks base method.
|
||||||
|
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||||
|
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
// SaveProxyAccessToken mocks base method.
|
// SaveProxyAccessToken mocks base method.
|
||||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -2995,17 +3010,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProxyHeartbeat mocks base method.
|
// UpdateProxyHeartbeat mocks base method.
|
||||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
||||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateService mocks base method.
|
// UpdateService mocks base method.
|
||||||
|
|||||||
@@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
|||||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||||
type testProxyManager struct{}
|
type testProxyManager struct{}
|
||||||
|
|
||||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
|
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||||
|
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
|
func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error {
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -656,3 +656,72 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
|||||||
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
|
||||||
|
// when a proxy reconnects before the old stream's cleanup runs, the new
|
||||||
|
// connection is NOT removed by the stale defer.
|
||||||
|
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
|
||||||
|
setup := setupIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
clusterAddress := "test.proxy.io"
|
||||||
|
proxyID := "test-proxy-race"
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: proxyID,
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: clusterAddress,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
_, err := stream1.Recv()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||||
|
"proxy should be registered after first connection")
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: proxyID,
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: clusterAddress,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
_, err := stream2.Recv()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||||
|
"proxy should still be registered after old connection cleanup — old defer must not remove new connection")
|
||||||
|
|
||||||
|
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: []*proto.ProxyMapping{{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||||
|
Id: "rp-1",
|
||||||
|
AccountId: "test-account-1",
|
||||||
|
Domain: "app1.test.proxy.io",
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
msg, err := stream2.Recv()
|
||||||
|
require.NoError(t, err, "new stream should still receive updates")
|
||||||
|
require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
|
||||||
|
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user