mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-05 16:46:39 +00:00
Compare commits
3 Commits
fix/login-
...
fix/debug-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6cb25de9ea | ||
|
|
97db824929 | ||
|
|
77a0992dc2 |
@@ -58,6 +58,11 @@ linters:
|
|||||||
govet:
|
govet:
|
||||||
enable:
|
enable:
|
||||||
- nilness
|
- nilness
|
||||||
|
disable:
|
||||||
|
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
|
||||||
|
# directives but cannot perform the rewrite due to generic type
|
||||||
|
# parameter inference limitations in the Go inliner.
|
||||||
|
- inline
|
||||||
enable-all: false
|
enable-all: false
|
||||||
revive:
|
revive:
|
||||||
rules:
|
rules:
|
||||||
|
|||||||
@@ -607,6 +607,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||||
}
|
}
|
||||||
|
if g.internalConfig.DisableSSHAuth != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
|
||||||
|
}
|
||||||
|
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||||
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
@@ -633,6 +639,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
}
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||||
|
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addProf() (err error) {
|
func (g *BundleGenerator) addProf() (err error) {
|
||||||
|
|||||||
@@ -5,16 +5,21 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/configs"
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
|||||||
assert.Contains(t, anonNftables, "chain input {")
|
assert.Contains(t, anonNftables, "chain input {")
|
||||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
|
||||||
|
// profilemanager.Config is either rendered in the debug bundle or explicitly
|
||||||
|
// excluded. When a new field is added to Config, this test fails until the
|
||||||
|
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
|
||||||
|
// the excluded set with a justification.
|
||||||
|
func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||||
|
excluded := map[string]string{
|
||||||
|
"PrivateKey": "sensitive: WireGuard private key",
|
||||||
|
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||||
|
"SSHKey": "sensitive: SSH private key",
|
||||||
|
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||||
|
}
|
||||||
|
|
||||||
|
mURL, _ := url.Parse("https://api.example.com:443")
|
||||||
|
aURL, _ := url.Parse("https://admin.example.com:443")
|
||||||
|
bTrue := true
|
||||||
|
iVal := 42
|
||||||
|
cfg := &profilemanager.Config{
|
||||||
|
PrivateKey: "priv",
|
||||||
|
PreSharedKey: "psk",
|
||||||
|
ManagementURL: mURL,
|
||||||
|
AdminURL: aURL,
|
||||||
|
WgIface: "wt0",
|
||||||
|
WgPort: 51820,
|
||||||
|
NetworkMonitor: &bTrue,
|
||||||
|
IFaceBlackList: []string{"eth0"},
|
||||||
|
DisableIPv6Discovery: true,
|
||||||
|
RosenpassEnabled: true,
|
||||||
|
RosenpassPermissive: true,
|
||||||
|
ServerSSHAllowed: &bTrue,
|
||||||
|
EnableSSHRoot: &bTrue,
|
||||||
|
EnableSSHSFTP: &bTrue,
|
||||||
|
EnableSSHLocalPortForwarding: &bTrue,
|
||||||
|
EnableSSHRemotePortForwarding: &bTrue,
|
||||||
|
DisableSSHAuth: &bTrue,
|
||||||
|
SSHJWTCacheTTL: &iVal,
|
||||||
|
DisableClientRoutes: true,
|
||||||
|
DisableServerRoutes: true,
|
||||||
|
DisableDNS: true,
|
||||||
|
DisableFirewall: true,
|
||||||
|
BlockLANAccess: true,
|
||||||
|
BlockInbound: true,
|
||||||
|
DisableNotifications: &bTrue,
|
||||||
|
DNSLabels: domain.List{},
|
||||||
|
SSHKey: "sshkey",
|
||||||
|
NATExternalIPs: []string{"1.2.3.4"},
|
||||||
|
CustomDNSAddress: "1.1.1.1:53",
|
||||||
|
DisableAutoConnect: true,
|
||||||
|
DNSRouteInterval: 5 * time.Second,
|
||||||
|
ClientCertPath: "/tmp/cert",
|
||||||
|
ClientCertKeyPath: "/tmp/key",
|
||||||
|
LazyConnectionEnabled: true,
|
||||||
|
MTU: 1280,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, anonymize := range []bool{false, true} {
|
||||||
|
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
|
||||||
|
g := &BundleGenerator{
|
||||||
|
anonymizer: newAnonymizerForTest(),
|
||||||
|
internalConfig: cfg,
|
||||||
|
anonymize: anonymize,
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
g.addCommonConfigFields(&sb)
|
||||||
|
rendered := sb.String() + renderAddConfigSpecific(g)
|
||||||
|
|
||||||
|
val := reflect.ValueOf(cfg).Elem()
|
||||||
|
typ := val.Type()
|
||||||
|
var missing []string
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
name := typ.Field(i).Name
|
||||||
|
if _, ok := excluded[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.Contains(rendered, name+":") {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
|
||||||
|
"Either render the field in addCommonConfigFields/addConfig, "+
|
||||||
|
"or add it to the excluded map with a justification.", missing)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
|
||||||
|
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
|
||||||
|
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
|
||||||
|
// production shape without needing to write an actual zip.
|
||||||
|
func renderAddConfigSpecific(g *BundleGenerator) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
if g.anonymize {
|
||||||
|
if g.internalConfig.ManagementURL != nil {
|
||||||
|
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
|
||||||
|
}
|
||||||
|
if g.internalConfig.AdminURL != nil {
|
||||||
|
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("NATExternalIPs: x\n")
|
||||||
|
if g.internalConfig.CustomDNSAddress != "" {
|
||||||
|
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if g.internalConfig.ManagementURL != nil {
|
||||||
|
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
|
||||||
|
}
|
||||||
|
if g.internalConfig.AdminURL != nil {
|
||||||
|
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("NATExternalIPs: x\n")
|
||||||
|
if g.internalConfig.CustomDNSAddress != "" {
|
||||||
|
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAnonymizerForTest() *anonymize.Anonymizer {
|
||||||
|
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||||
|
}
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -309,8 +309,8 @@ require (
|
|||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
rsc.io/qr v0.2.0 // indirect
|
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
|
rsc.io/qr v0.2.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
|
|
||||||
// Manager defines the interface for proxy operations
|
// Manager defines the interface for proxy operations
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||||
Disconnect(ctx context.Context, proxyID string) error
|
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
Heartbeat(ctx context.Context, p *Proxy) error
|
||||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ import (
|
|||||||
// store defines the interface for proxy persistence operations
|
// store defines the interface for proxy persistence operations
|
||||||
type store interface {
|
type store interface {
|
||||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||||
|
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
@@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
|||||||
|
|
||||||
// Connect registers a new proxy connection in the database.
|
// Connect registers a new proxy connection in the database.
|
||||||
// capabilities may be nil for old proxies that do not report them.
|
// capabilities may be nil for old proxies that do not report them.
|
||||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
var caps proxy.Capabilities
|
var caps proxy.Capabilities
|
||||||
if capabilities != nil {
|
if capabilities != nil {
|
||||||
@@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
|||||||
}
|
}
|
||||||
p := &proxy.Proxy{
|
p := &proxy.Proxy{
|
||||||
ID: proxyID,
|
ID: proxyID,
|
||||||
|
SessionID: sessionID,
|
||||||
ClusterAddress: clusterAddress,
|
ClusterAddress: clusterAddress,
|
||||||
IPAddress: ipAddress,
|
IPAddress: ipAddress,
|
||||||
LastSeen: now,
|
LastSeen: now,
|
||||||
@@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
|||||||
|
|
||||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
"proxyID": proxyID,
|
"proxyID": proxyID,
|
||||||
|
"sessionID": sessionID,
|
||||||
"clusterAddress": clusterAddress,
|
"clusterAddress": clusterAddress,
|
||||||
"ipAddress": ipAddress,
|
"ipAddress": ipAddress,
|
||||||
}).Info("proxy connected")
|
}).Info("proxy connected")
|
||||||
|
|
||||||
return nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect marks a proxy as disconnected in the database
|
// Disconnect marks a proxy as disconnected in the database.
|
||||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||||
now := time.Now()
|
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||||
p := &proxy.Proxy{
|
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||||
ID: proxyID,
|
|
||||||
Status: "disconnected",
|
|
||||||
DisconnectedAt: &now,
|
|
||||||
LastSeen: now,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
"proxyID": proxyID,
|
"proxyID": proxyID,
|
||||||
|
"sessionID": sessionID,
|
||||||
}).Info("proxy disconnected")
|
}).Info("proxy disconnected")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat updates the proxy's last seen timestamp
|
// Heartbeat updates the proxy's last seen timestamp.
|
||||||
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
|
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
|
||||||
m.metrics.IncrementProxyHeartbeatCount()
|
m.metrics.IncrementProxyHeartbeatCount()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect mocks base method.
|
// Connect mocks base method.
|
||||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(*Proxy)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect indicates an expected call of Connect.
|
// Connect indicates an expected call of Connect.
|
||||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect mocks base method.
|
// Disconnect mocks base method.
|
||||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect indicates an expected call of Disconnect.
|
// Disconnect indicates an expected call of Disconnect.
|
||||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusterAddresses mocks base method.
|
// GetActiveClusterAddresses mocks base method.
|
||||||
@@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat mocks base method.
|
// Heartbeat mocks base method.
|
||||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat indicates an expected call of Heartbeat.
|
// Heartbeat indicates an expected call of Heartbeat.
|
||||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockController is a mock of Controller interface.
|
// MockController is a mock of Controller interface.
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Capabilities struct {
|
|||||||
// Proxy represents a reverse proxy instance
|
// Proxy represents a reverse proxy instance
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||||
|
SessionID string `gorm:"type:varchar(36)"`
|
||||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||||
IPAddress string `gorm:"type:varchar(45)"`
|
IPAddress string `gorm:"type:varchar(45)"`
|
||||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@@ -89,6 +90,7 @@ const pkceVerifierTTL = 10 * time.Minute
|
|||||||
// proxyConnection represents a connected proxy
|
// proxyConnection represents a connected proxy
|
||||||
type proxyConnection struct {
|
type proxyConnection struct {
|
||||||
proxyID string
|
proxyID string
|
||||||
|
sessionID string
|
||||||
address string
|
address string
|
||||||
capabilities *proto.ProxyCapabilities
|
capabilities *proto.ProxyCapabilities
|
||||||
stream proto.ProxyService_GetMappingUpdateServer
|
stream proto.ProxyService_GetMappingUpdateServer
|
||||||
@@ -166,9 +168,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionID := uuid.NewString()
|
||||||
|
|
||||||
|
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||||
|
oldConn := old.(*proxyConnection)
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"proxy_id": proxyID,
|
||||||
|
"old_session_id": oldConn.sessionID,
|
||||||
|
"new_session_id": sessionID,
|
||||||
|
}).Info("Superseding existing proxy connection")
|
||||||
|
oldConn.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithCancel(ctx)
|
connCtx, cancel := context.WithCancel(ctx)
|
||||||
conn := &proxyConnection{
|
conn := &proxyConnection{
|
||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
|
sessionID: sessionID,
|
||||||
address: proxyAddress,
|
address: proxyAddress,
|
||||||
capabilities: req.GetCapabilities(),
|
capabilities: req.GetCapabilities(),
|
||||||
stream: stream,
|
stream: stream,
|
||||||
@@ -191,9 +206,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||||
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||||
s.connectedProxies.Delete(proxyID)
|
s.connectedProxies.CompareAndDelete(proxyID, conn)
|
||||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||||
}
|
}
|
||||||
@@ -202,22 +218,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"proxy_id": proxyID,
|
"proxy_id": proxyID,
|
||||||
|
"session_id": sessionID,
|
||||||
"address": proxyAddress,
|
"address": proxyAddress,
|
||||||
"cluster_addr": proxyAddress,
|
"cluster_addr": proxyAddress,
|
||||||
"total_proxies": len(s.GetConnectedProxies()),
|
"total_proxies": len(s.GetConnectedProxies()),
|
||||||
}).Info("Proxy registered in cluster")
|
}).Info("Proxy registered in cluster")
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Delete(proxyID)
|
|
||||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||||
}
|
}
|
||||||
|
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||||
|
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
log.Infof("Proxy %s disconnected", proxyID)
|
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||||
@@ -227,29 +248,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
go s.sender(conn, errChan)
|
go s.sender(conn, errChan)
|
||||||
|
|
||||||
// Start heartbeat goroutine
|
go s.heartbeat(connCtx, proxyRecord)
|
||||||
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||||
case <-connCtx.Done():
|
case <-connCtx.Done():
|
||||||
|
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||||
return connCtx.Err()
|
return connCtx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
|
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5437,13 +5437,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
|
// DisconnectProxy marks a proxy as disconnected only if the session ID matches.
|
||||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
// This prevents a slow-to-close old session from overwriting a newer reconnection.
|
||||||
|
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||||
|
now := time.Now()
|
||||||
|
result := s.db.
|
||||||
|
Model(&proxy.Proxy{}).
|
||||||
|
Where("id = ? AND session_id = ?", proxyID, sessionID).
|
||||||
|
Updates(map[string]any{
|
||||||
|
"status": "disconnected",
|
||||||
|
"disconnected_at": now,
|
||||||
|
"last_seen": now,
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to disconnect proxy")
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
|
||||||
|
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
result := s.db.
|
result := s.db.
|
||||||
Model(&proxy.Proxy{}).
|
Model(&proxy.Proxy{}).
|
||||||
Where("id = ? AND status = ?", proxyID, "connected").
|
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
|
||||||
Update("last_seen", now)
|
Update("last_seen", now)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -5452,17 +5474,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
p := &proxy.Proxy{
|
p.LastSeen = now
|
||||||
ID: proxyID,
|
p.ConnectedAt = &now
|
||||||
ClusterAddress: clusterAddress,
|
p.Status = "connected"
|
||||||
IPAddress: ipAddress,
|
if err := s.db.Create(p).Error; err != nil {
|
||||||
LastSeen: now,
|
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
|
||||||
ConnectedAt: &now,
|
|
||||||
Status: "connected",
|
|
||||||
}
|
|
||||||
if err := s.db.Save(p).Error; err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
|
|
||||||
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -284,7 +284,8 @@ type Store interface {
|
|||||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||||
|
|
||||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||||
|
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close mocks base method.
|
// Close mocks base method.
|
||||||
func (m *MockStore) Close(ctx context.Context) error {
|
func (m *MockStore) Close(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -2799,6 +2800,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DisconnectProxy mocks base method.
|
||||||
|
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||||
|
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
// SaveProxyAccessToken mocks base method.
|
// SaveProxyAccessToken mocks base method.
|
||||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -2995,17 +3010,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProxyHeartbeat mocks base method.
|
// UpdateProxyHeartbeat mocks base method.
|
||||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
||||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateService mocks base method.
|
// UpdateService mocks base method.
|
||||||
|
|||||||
@@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
|||||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||||
type testProxyManager struct{}
|
type testProxyManager struct{}
|
||||||
|
|
||||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
|
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||||
|
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
|
func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error {
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -656,3 +656,72 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
|||||||
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
|
||||||
|
// when a proxy reconnects before the old stream's cleanup runs, the new
|
||||||
|
// connection is NOT removed by the stale defer.
|
||||||
|
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
|
||||||
|
setup := setupIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
clusterAddress := "test.proxy.io"
|
||||||
|
proxyID := "test-proxy-race"
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: proxyID,
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: clusterAddress,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
_, err := stream1.Recv()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||||
|
"proxy should be registered after first connection")
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: proxyID,
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: clusterAddress,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
_, err := stream2.Recv()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||||
|
"proxy should still be registered after old connection cleanup — old defer must not remove new connection")
|
||||||
|
|
||||||
|
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: []*proto.ProxyMapping{{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||||
|
Id: "rp-1",
|
||||||
|
AccountId: "test-account-1",
|
||||||
|
Domain: "app1.test.proxy.io",
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
msg, err := stream2.Recv()
|
||||||
|
require.NoError(t, err, "new stream should still receive updates")
|
||||||
|
require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
|
||||||
|
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user