Compare commits

...

3 Commits

Author SHA1 Message Date
Viktor Liu
6cb25de9ea Include MTU and SSH auth/JWT cache config in debug bundle 2026-05-05 12:18:56 +02:00
Pascal Fischer
97db824929 [management] fix proxy reconnect (#6063) 2026-05-04 20:43:25 +02:00
Viktor Liu
77a0992dc2 [misc] Disable govet inline analyzer and tidy go.mod (#6066) 2026-05-05 02:59:41 +09:00
13 changed files with 346 additions and 83 deletions

View File

@@ -58,6 +58,11 @@ linters:
govet: govet:
enable: enable:
- nilness - nilness
disable:
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
# directives but cannot perform the rewrite due to generic type
# parameter inference limitations in the Go inliner.
- inline
enable-all: false enable-all: false
revive: revive:
rules: rules:

View File

@@ -607,6 +607,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil { if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
} }
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -633,6 +639,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
} }
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {

View File

@@ -5,16 +5,21 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"net" "net"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"HOME": "/root", "HOME": "/root",
"PATH": "/usr/bin", "PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug", "NB_LOG_LEVEL": "debug",
}, },
}, },
@@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123", "NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz", "NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info", "NB_LOG_LEVEL": "info",
}, },
}, },
check: func(t *testing.T, params map[string]any) { check: func(t *testing.T, params map[string]any) {
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
} }
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
	cfg := g.internalConfig
	var out strings.Builder

	if cfg.ManagementURL != nil {
		v := cfg.ManagementURL.String()
		if g.anonymize {
			v = g.anonymizer.AnonymizeURI(v)
		}
		out.WriteString("ManagementURL: " + v + "\n")
	}
	if cfg.AdminURL != nil {
		v := cfg.AdminURL.String()
		if g.anonymize {
			v = g.anonymizer.AnonymizeURI(v)
		}
		out.WriteString("AdminURL: " + v + "\n")
	}

	out.WriteString("NATExternalIPs: x\n")

	if cfg.CustomDNSAddress != "" {
		v := cfg.CustomDNSAddress
		if g.anonymize {
			v = g.anonymizer.AnonymizeString(v)
		}
		out.WriteString("CustomDNSAddress: " + v + "\n")
	}

	return out.String()
}
// newAnonymizerForTest builds an anonymizer seeded with the default
// replacement addresses, matching what the production bundle uses.
func newAnonymizerForTest() *anonymize.Anonymizer {
	addrs := anonymize.DefaultAddresses()
	return anonymize.NewAnonymizer(addrs)
}

2
go.mod
View File

@@ -309,8 +309,8 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
rsc.io/qr v0.2.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
rsc.io/qr v0.2.0 // indirect
) )
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502

View File

@@ -11,9 +11,9 @@ import (
// Manager defines the interface for proxy operations // Manager defines the interface for proxy operations
type Manager interface { type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
Disconnect(ctx context.Context, proxyID string) error Disconnect(ctx context.Context, proxyID, sessionID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error Heartbeat(ctx context.Context, p *Proxy) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error) GetActiveClusters(ctx context.Context) ([]Cluster, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool

View File

@@ -13,7 +13,8 @@ import (
// store defines the interface for proxy persistence operations // store defines the interface for proxy persistence operations
type store interface { type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
@@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database. // Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them. // capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
now := time.Now() now := time.Now()
var caps proxy.Capabilities var caps proxy.Capabilities
if capabilities != nil { if capabilities != nil {
@@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
} }
p := &proxy.Proxy{ p := &proxy.Proxy{
ID: proxyID, ID: proxyID,
SessionID: sessionID,
ClusterAddress: clusterAddress, ClusterAddress: clusterAddress,
IPAddress: ipAddress, IPAddress: ipAddress,
LastSeen: now, LastSeen: now,
@@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
if err := m.store.SaveProxy(ctx, p); err != nil { if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err) log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err return nil, err
} }
log.WithContext(ctx).WithFields(log.Fields{ log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID, "proxyID": proxyID,
"sessionID": sessionID,
"clusterAddress": clusterAddress, "clusterAddress": clusterAddress,
"ipAddress": ipAddress, "ipAddress": ipAddress,
}).Info("proxy connected") }).Info("proxy connected")
return nil return p, nil
} }
// Disconnect marks a proxy as disconnected in the database // Disconnect marks a proxy as disconnected in the database.
func (m Manager) Disconnect(ctx context.Context, proxyID string) error { func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
now := time.Now() if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
p := &proxy.Proxy{ log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err return err
} }
log.WithContext(ctx).WithFields(log.Fields{ log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID, "proxyID": proxyID,
"sessionID": sessionID,
}).Info("proxy disconnected") }).Info("proxy disconnected")
return nil return nil
} }
// Heartbeat updates the proxy's last seen timestamp // Heartbeat updates the proxy's last seen timestamp.
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
return err return err
} }
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID) log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
m.metrics.IncrementProxyHeartbeatCount() m.metrics.IncrementProxyHeartbeatCount()
return nil return nil
} }

View File

@@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
} }
// Connect mocks base method. // Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(*Proxy)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// Connect indicates an expected call of Connect. // Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
} }
// Disconnect mocks base method. // Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error { func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID) ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Disconnect indicates an expected call of Disconnect. // Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
} }
// GetActiveClusterAddresses mocks base method. // GetActiveClusterAddresses mocks base method.
@@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
} }
// Heartbeat mocks base method. // Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress) ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Heartbeat indicates an expected call of Heartbeat. // Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
} }
// MockController is a mock of Controller interface. // MockController is a mock of Controller interface.

View File

@@ -18,12 +18,13 @@ type Capabilities struct {
// Proxy represents a reverse proxy instance // Proxy represents a reverse proxy instance
type Proxy struct { type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"` ID string `gorm:"primaryKey;type:varchar(255)"`
SessionID string `gorm:"type:varchar(36)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"` IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time ConnectedAt *time.Time
DisconnectedAt *time.Time DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"` Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time

View File

@@ -16,6 +16,7 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -89,6 +90,7 @@ const pkceVerifierTTL = 10 * time.Minute
// proxyConnection represents a connected proxy // proxyConnection represents a connected proxy
type proxyConnection struct { type proxyConnection struct {
proxyID string proxyID string
sessionID string
address string address string
capabilities *proto.ProxyCapabilities capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer stream proto.ProxyService_GetMappingUpdateServer
@@ -166,9 +168,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid") return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
} }
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": sessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
connCtx, cancel := context.WithCancel(ctx) connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{ conn := &proxyConnection{
proxyID: proxyID, proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress, address: proxyAddress,
capabilities: req.GetCapabilities(), capabilities: req.GetCapabilities(),
stream: stream, stream: stream,
@@ -188,12 +203,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
caps = &proxy.Capabilities{ caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts, SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain, RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec, SupportsCrowdsec: c.SupportsCrowdsec,
} }
} }
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.Delete(proxyID) s.connectedProxies.CompareAndDelete(proxyID, conn)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
} }
@@ -202,22 +218,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"proxy_id": proxyID, "proxy_id": proxyID,
"session_id": sessionID,
"address": proxyAddress, "address": proxyAddress,
"cluster_addr": proxyAddress, "cluster_addr": proxyAddress,
"total_proxies": len(s.GetConnectedProxies()), "total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster") }).Info("Proxy registered in cluster")
defer func() { defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil { if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
cancel()
return
} }
s.connectedProxies.Delete(proxyID)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
} }
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
cancel() cancel()
log.Infof("Proxy %s disconnected", proxyID) log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}() }()
if err := s.sendSnapshot(ctx, conn); err != nil { if err := s.sendSnapshot(ctx, conn); err != nil {
@@ -227,29 +248,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2) errChan := make(chan error, 2)
go s.sender(conn, errChan) go s.sender(conn, errChan)
// Start heartbeat goroutine go s.heartbeat(connCtx, proxyRecord)
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
select { select {
case err := <-errChan: case err := <-errChan:
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", proxyID, err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done(): case <-connCtx.Done():
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
return connCtx.Err() return connCtx.Err()
} }
} }
// heartbeat updates the proxy's last_seen timestamp every minute // heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) { func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
} }
case <-ctx.Done(): case <-ctx.Done():
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
return return
} }
} }

View File

@@ -5437,13 +5437,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
return nil return nil
} }
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist // DisconnectProxy marks a proxy as disconnected only if the session ID matches.
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { // This prevents a slow-to-close old session from overwriting a newer reconnection.
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
now := time.Now()
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", proxyID, sessionID).
Updates(map[string]any{
"status": "disconnected",
"disconnected_at": now,
"last_seen": now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error)
return status.Errorf(status.Internal, "failed to disconnect proxy")
}
if result.RowsAffected == 0 {
log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID)
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
now := time.Now() now := time.Now()
result := s.db. result := s.db.
Model(&proxy.Proxy{}). Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected"). Where("id = ? AND session_id = ?", p.ID, p.SessionID).
Update("last_seen", now) Update("last_seen", now)
if result.Error != nil { if result.Error != nil {
@@ -5452,17 +5474,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
p := &proxy.Proxy{ p.LastSeen = now
ID: proxyID, p.ConnectedAt = &now
ClusterAddress: clusterAddress, p.Status = "connected"
IPAddress: ipAddress, if err := s.db.Create(p).Error; err != nil {
LastSeen: now, log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
ConnectedAt: &now,
Status: "connected",
}
if err := s.db.Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
} }
} }

View File

@@ -284,7 +284,8 @@ type Store interface {
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool

View File

@@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
} }
// Close mocks base method. // Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error { func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2799,6 +2800,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
} }
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
ret0, _ := ret[0].(error)
return ret0
}
// DisconnectProxy indicates an expected call of DisconnectProxy.
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
}
// SaveProxyAccessToken mocks base method. // SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2995,17 +3010,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
} }
// UpdateProxyHeartbeat mocks base method. // UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress) ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. // UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p)
} }
// UpdateService mocks base method. // UpdateService mocks base method.

View File

@@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
// testProxyManager is a mock implementation of proxy.Manager for testing. // testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{} type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error { func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
return nil return nil
} }
@@ -656,3 +656,72 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID) assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
} }
} }
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
// when a proxy reconnects before the old stream's cleanup runs, the new
// connection is NOT removed by the stale defer.
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
	env := setupIntegrationTest(t)
	defer env.cleanup()

	clusterAddress := "test.proxy.io"
	proxyID := "test-proxy-race"

	clientConn, err := grpc.NewClient(env.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	require.NoError(t, err)
	defer clientConn.Close()

	client := proto.NewProxyServiceClient(clientConn)

	// Open the first ("old") stream and drain its two initial messages.
	oldCtx, cancelOld := context.WithCancel(context.Background())
	oldStream, err := client.GetMappingUpdate(oldCtx, &proto.GetMappingUpdateRequest{
		ProxyId: proxyID,
		Version: "test-v1",
		Address: clusterAddress,
	})
	require.NoError(t, err)
	_, err = oldStream.Recv()
	require.NoError(t, err)
	_, err = oldStream.Recv()
	require.NoError(t, err)

	require.Contains(t, env.proxyService.GetConnectedProxies(), proxyID,
		"proxy should be registered after first connection")

	// Reconnect under the same proxy ID while the old stream is still open,
	// and drain the new stream's two initial messages as well.
	newCtx, cancelNew := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelNew()
	newStream, err := client.GetMappingUpdate(newCtx, &proto.GetMappingUpdateRequest{
		ProxyId: proxyID,
		Version: "test-v1",
		Address: clusterAddress,
	})
	require.NoError(t, err)
	_, err = newStream.Recv()
	require.NoError(t, err)
	_, err = newStream.Recv()
	require.NoError(t, err)

	// Kill the old stream and give its server-side cleanup time to run.
	cancelOld()
	time.Sleep(200 * time.Millisecond)

	assert.Contains(t, env.proxyService.GetConnectedProxies(), proxyID,
		"proxy should still be registered after old connection cleanup — old defer must not remove new connection")

	// A broadcast after the stale cleanup must still reach the new stream.
	env.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
		Mapping: []*proto.ProxyMapping{{
			Type:      proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
			Id:        "rp-1",
			AccountId: "test-account-1",
			Domain:    "app1.test.proxy.io",
		}},
	})

	msg, err := newStream.Recv()
	require.NoError(t, err, "new stream should still receive updates")
	require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
	assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
}