mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-07 09:19:59 +00:00
Compare commits
7 Commits
dns-skip-f
...
fix/login-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aad7cc915f | ||
|
|
b19b7464ea | ||
|
|
cfb1b3fe31 | ||
|
|
3c28d29725 | ||
|
|
b7160fe7c0 | ||
|
|
f5bff93f01 | ||
|
|
43d4d54f40 |
93
client/server/login_overrides_test.go
Normal file
93
client/server/login_overrides_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
func TestPersistLoginOverrides(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialMgmtURL string
|
||||
initialPSK string
|
||||
newMgmtURL string
|
||||
newPSK *string
|
||||
wantMgmtURL string
|
||||
wantPSK string
|
||||
}{
|
||||
{
|
||||
name: "persist new management URL",
|
||||
initialMgmtURL: "https://old.example.com:33073",
|
||||
newMgmtURL: "https://new.example.com:33073",
|
||||
wantMgmtURL: "https://new.example.com:33073",
|
||||
},
|
||||
{
|
||||
name: "persist new pre-shared key",
|
||||
initialMgmtURL: "https://existing.example.com:33073",
|
||||
initialPSK: "old-key",
|
||||
newPSK: strPtr("new-key"),
|
||||
wantMgmtURL: "https://existing.example.com:33073",
|
||||
wantPSK: "new-key",
|
||||
},
|
||||
{
|
||||
name: "persist both",
|
||||
initialMgmtURL: "https://old.example.com:33073",
|
||||
initialPSK: "old-key",
|
||||
newMgmtURL: "https://new.example.com:33073",
|
||||
newPSK: strPtr("new-key"),
|
||||
wantMgmtURL: "https://new.example.com:33073",
|
||||
wantPSK: "new-key",
|
||||
},
|
||||
{
|
||||
name: "no inputs preserves existing",
|
||||
initialMgmtURL: "https://existing.example.com:33073",
|
||||
initialPSK: "existing-key",
|
||||
wantMgmtURL: "https://existing.example.com:33073",
|
||||
wantPSK: "existing-key",
|
||||
},
|
||||
{
|
||||
name: "empty PSK pointer is ignored",
|
||||
initialMgmtURL: "https://existing.example.com:33073",
|
||||
initialPSK: "existing-key",
|
||||
newPSK: strPtr(""),
|
||||
wantMgmtURL: "https://existing.example.com:33073",
|
||||
wantPSK: "existing-key",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
origDefault := profilemanager.DefaultConfigPath
|
||||
t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault })
|
||||
|
||||
dir := t.TempDir()
|
||||
profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json")
|
||||
|
||||
seed := profilemanager.ConfigInput{
|
||||
ConfigPath: profilemanager.DefaultConfigPath,
|
||||
ManagementURL: tt.initialMgmtURL,
|
||||
}
|
||||
if tt.initialPSK != "" {
|
||||
seed.PreSharedKey = strPtr(tt.initialPSK)
|
||||
}
|
||||
_, err := profilemanager.UpdateOrCreateConfig(seed)
|
||||
require.NoError(t, err, "seed config")
|
||||
|
||||
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
|
||||
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
|
||||
require.NoError(t, err, "persistLoginOverrides")
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath)
|
||||
require.NoError(t, err, "read back config")
|
||||
|
||||
require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL")
|
||||
require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -489,6 +489,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil {
|
||||
log.Errorf("failed to persist login overrides: %v", err)
|
||||
return nil, fmt.Errorf("persist login overrides: %w", err)
|
||||
}
|
||||
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
@@ -963,7 +968,33 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
|
||||
return &proto.LogoutResponse{}, nil
|
||||
}
|
||||
|
||||
// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
|
||||
// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the
|
||||
// active profile config so that subsequent reads pick them up. Empty/nil values are ignored.
|
||||
func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error {
|
||||
if preSharedKey != nil && *preSharedKey == "" {
|
||||
preSharedKey = nil
|
||||
}
|
||||
if managementURL == "" && preSharedKey == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("active profile file path: %w", err)
|
||||
}
|
||||
|
||||
input := profilemanager.ConfigInput{
|
||||
ConfigPath: cfgPath,
|
||||
ManagementURL: managementURL,
|
||||
PreSharedKey: preSharedKey,
|
||||
}
|
||||
if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil {
|
||||
return fmt.Errorf("update config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
|
||||
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
|
||||
@@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing connector in Dex storage.
|
||||
// It merges incoming updates with existing values to prevent data loss on partial updates.
|
||||
// It overlays user-mutable config fields (issuer, clientID, clientSecret,
|
||||
// redirectURI) onto the stored connector config, and updates the connector name
|
||||
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
|
||||
// partial updates preserve create-time defaults such as scopes, claimMapping,
|
||||
// and userIDKey.
|
||||
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
|
||||
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||
oldCfg, err := p.parseStorageConnector(old)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
|
||||
if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
|
||||
return storage.Connector{}, errors.New("connector type change not allowed")
|
||||
}
|
||||
|
||||
mergeConnectorConfig(cfg, oldCfg)
|
||||
|
||||
storageConn, err := p.buildStorageConnector(cfg)
|
||||
configData, err := overlayConnectorConfig(old.Config, cfg)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
|
||||
return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
|
||||
}
|
||||
return storageConn, nil
|
||||
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = old.Name
|
||||
}
|
||||
|
||||
return storage.Connector{
|
||||
ID: cfg.ID,
|
||||
Type: old.Type,
|
||||
Name: name,
|
||||
Config: configData,
|
||||
}, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update connector: %w", err)
|
||||
}
|
||||
@@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeConnectorConfig preserves existing values for empty fields in the update.
|
||||
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
|
||||
if cfg.ClientSecret == "" {
|
||||
cfg.ClientSecret = oldCfg.ClientSecret
|
||||
// overlayConnectorConfig writes only the user-mutable fields onto the existing
|
||||
// stored config, preserving every other field (scopes, claimMapping, userIDKey,
|
||||
// insecure flags, etc.). Empty fields on cfg leave the existing value alone.
|
||||
func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
|
||||
var m map[string]any
|
||||
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.RedirectURI == "" {
|
||||
cfg.RedirectURI = oldCfg.RedirectURI
|
||||
if cfg.Issuer != "" {
|
||||
m["issuer"] = cfg.Issuer
|
||||
}
|
||||
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
|
||||
cfg.Issuer = oldCfg.Issuer
|
||||
if cfg.ClientID != "" {
|
||||
m["clientID"] = cfg.ClientID
|
||||
}
|
||||
if cfg.ClientID == "" {
|
||||
cfg.ClientID = oldCfg.ClientID
|
||||
if cfg.ClientSecret != "" {
|
||||
m["clientSecret"] = cfg.ClientSecret
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
cfg.Name = oldCfg.Name
|
||||
if cfg.RedirectURI != "" {
|
||||
m["redirectURI"] = cfg.RedirectURI
|
||||
}
|
||||
return encodeConnectorConfig(m)
|
||||
}
|
||||
|
||||
// DeleteConnector removes a connector from Dex storage.
|
||||
@@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["getUserInfo"] = true
|
||||
case "entra":
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
|
||||
// Entra issues sub as a per-app pairwise identifier that does not match
|
||||
// the stable Object ID.
|
||||
oidcConfig["userIDKey"] = "oid"
|
||||
case "okta":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
|
||||
205
idp/dex/connector_test.go
Normal file
205
idp/dex/connector_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestProvider(t *testing.T) (*Provider, func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &Provider{storage: s, logger: logger}, func() {
|
||||
_ = s.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
|
||||
cfg := &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
}
|
||||
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &m))
|
||||
|
||||
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
|
||||
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
|
||||
}
|
||||
|
||||
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
|
||||
// ensures the Entra userIDKey override does not leak into other OIDC providers,
|
||||
// which already use a stable sub claim.
|
||||
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
|
||||
t.Run(typ, func(t *testing.T) {
|
||||
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &m))
|
||||
_, ok := m["userIDKey"]
|
||||
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
created, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "old-secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "entra-test", created.ID)
|
||||
|
||||
// Rotate only the client secret.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "entra",
|
||||
ClientSecret: "new-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
|
||||
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
|
||||
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
|
||||
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
|
||||
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
|
||||
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
|
||||
}
|
||||
|
||||
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
// Seed a connector directly into storage without userIDKey
|
||||
preFixConfig, err := json.Marshal(map[string]any{
|
||||
"issuer": "https://login.microsoftonline.com/tid/v2.0",
|
||||
"clientID": "client-id",
|
||||
"clientSecret": "old-secret",
|
||||
"redirectURI": "https://example.com/oauth2/callback",
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"claimMapping": map[string]string{"email": "preferred_username"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
|
||||
ID: "entra-prefix",
|
||||
Type: "oidc",
|
||||
Name: "Entra",
|
||||
Config: preFixConfig,
|
||||
}))
|
||||
|
||||
// Rotate client secret via UpdateConnector.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-prefix",
|
||||
Type: "entra",
|
||||
ClientSecret: "new-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
|
||||
assert.Equal(t, "new-secret", m["clientSecret"])
|
||||
_, has := m["userIDKey"]
|
||||
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
|
||||
}
|
||||
|
||||
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to switch the connector to okta.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "okta",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "connector type change not allowed")
|
||||
|
||||
// stored connector type/config unchanged after the rejected update.
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "oidc", conn.Type)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
assert.Equal(t, "oid", m["userIDKey"])
|
||||
}
|
||||
|
||||
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/old/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/new/v2.0",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
|
||||
}
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -82,11 +84,40 @@ type ProxyServiceServer struct {
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
|
||||
// tokenTTL is the lifetime of one-time tokens generated for proxy
|
||||
// authentication. Defaults to defaultProxyTokenTTL when zero.
|
||||
tokenTTL time.Duration
|
||||
|
||||
// snapshotBatchSize is the number of mappings per gRPC message during
|
||||
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
|
||||
snapshotBatchSize int
|
||||
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
const defaultProxyTokenTTL = 5 * time.Minute
|
||||
|
||||
// defaultSnapshotBatchSize is the snapshot batch size used when
// NB_PROXY_SNAPSHOT_BATCH_SIZE is unset or does not hold a positive integer.
const defaultSnapshotBatchSize = 500

// snapshotBatchSizeFromEnv returns the batch size configured through the
// NB_PROXY_SNAPSHOT_BATCH_SIZE environment variable. Surrounding whitespace is
// ignored so values such as "250 " coming from env files are still honoured.
// An unset, empty, non-numeric, zero, or negative value yields
// defaultSnapshotBatchSize.
func snapshotBatchSizeFromEnv() int {
	raw := strings.TrimSpace(os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"))
	if raw == "" {
		return defaultSnapshotBatchSize
	}

	n, err := strconv.Atoi(raw)
	if err != nil || n <= 0 {
		return defaultSnapshotBatchSize
	}
	return n
}
|
||||
|
||||
// proxyTokenTTL returns the configured token TTL or the default when unset.
|
||||
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
|
||||
if s.tokenTTL > 0 {
|
||||
return s.tokenTTL
|
||||
}
|
||||
return defaultProxyTokenTTL
|
||||
}
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
@@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
cancel: cancel,
|
||||
}
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
@@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
@@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
s.connectedProxies.CompareAndDelete(proxyID, conn)
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
cancel()
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
@@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
go s.heartbeat(connCtx, proxyRecord)
|
||||
|
||||
select {
|
||||
@@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return err
|
||||
}
|
||||
|
||||
// Send mappings in batches to reduce per-message gRPC overhead while
|
||||
// staying well within the default 4 MB message size limit.
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, m := range mappings {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{m},
|
||||
InitialSyncComplete: i == len(mappings)-1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"service": service.Name,
|
||||
"account": service.AccountID,
|
||||
}).WithError(err).Error("failed to generate auth token for snapshot")
|
||||
continue
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
@@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
conn := value.(*proxyConnection)
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- resp:
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
// Returns nil if token generation fails (the proxy should be skipped).
|
||||
// Returns nil if token generation fails; the caller must disconnect the
|
||||
// proxy so it can resync via a fresh snapshot on reconnect.
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, mapping := range update.Mapping {
|
||||
@@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
|
||||
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// recordingStream captures all messages sent via Send so tests can inspect
|
||||
// batching behaviour without a real gRPC transport.
|
||||
type recordingStream struct {
|
||||
grpc.ServerStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
}
|
||||
|
||||
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
s.messages = append(s.messages, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *recordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *recordingStream) SendMsg(any) error { return nil }
|
||||
func (s *recordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// makeServices creates n enabled services assigned to the given cluster.
|
||||
func makeServices(n int, cluster string) []*rpservice.Service {
|
||||
services := make([]*rpservice.Service, n)
|
||||
for i := range n {
|
||||
services[i] = &rpservice.Service{
|
||||
ID: fmt.Sprintf("svc-%d", i),
|
||||
AccountID: "acct-1",
|
||||
Name: fmt.Sprintf("svc-%d", i),
|
||||
Domain: fmt.Sprintf("svc-%d.example.com", i),
|
||||
ProxyCluster: cluster,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
|
||||
},
|
||||
}
|
||||
}
|
||||
return services
|
||||
}
|
||||
|
||||
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
|
||||
t.Helper()
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
|
||||
snapshotBatchSize: batchSize,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
return s
|
||||
}
|
||||
|
||||
func TestSendSnapshot_BatchesMappings(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshot(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect ceil(7/3) = 3 messages
|
||||
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[2].Mapping, 1)
|
||||
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
|
||||
|
||||
// Verify all service IDs are present exactly once
|
||||
seen := make(map[string]bool)
|
||||
for _, msg := range stream.messages {
|
||||
for _, m := range msg.Mapping {
|
||||
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
|
||||
seen[m.Id] = true
|
||||
}
|
||||
}
|
||||
assert.Len(t, seen, totalServices)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 6 // exactly 2 batches
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 2)
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.True(t, stream.messages[1].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_SingleBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
|
||||
assert.Len(t, stream.messages[0].Mapping, totalServices)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
|
||||
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
|
||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
capabilities: caps,
|
||||
sendChan: ch,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
|
||||
@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify one character in the secret part
|
||||
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
|
||||
replacement := "X"
|
||||
if plainToken[5] == 'X' {
|
||||
replacement = "Y"
|
||||
}
|
||||
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
|
||||
err = ValidateInviteToken(modifiedToken)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings from the snapshot - the server may split it across several messages and flags the last one with InitialSyncComplete
|
||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
for _, m := range msg.GetMapping() {
|
||||
mappingsByID[m.GetId()] = m
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive 2 mappings total
|
||||
@@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - the server may split the snapshot across several messages and flags the last one with InitialSyncComplete
|
||||
mappings := make([]*proto.ProxyMapping, 0)
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive the 2 mappings matching the cluster
|
||||
@@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-reconnect"
|
||||
|
||||
// Helper to receive all mappings from a stream
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||
var mappings []*proto.ProxyMapping
|
||||
for i := 0; i < count; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
@@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveMappings(stream1, 2)
|
||||
firstMappings := receiveMappings(stream1)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveMappings(stream2, 2)
|
||||
secondMappings := receiveMappings(stream2)
|
||||
|
||||
// Should receive the same mappings
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings),
|
||||
@@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to receive and apply all mappings
|
||||
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
applyMappings(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - the server may split the snapshot across several messages and flags the last one with InitialSyncComplete
|
||||
count := 0
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
count += len(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
@@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := stream1.Recv()
|
||||
for {
|
||||
msg, err := stream1.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
@@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := stream2.Recv()
|
||||
for {
|
||||
msg, err := stream2.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cancel1()
|
||||
|
||||
@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
operation := func() error {
|
||||
s.Logger.Debug("connecting to management mapping stream")
|
||||
|
||||
initialSyncDone = false
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var snapshotIDs map[types.ServiceID]struct{}
|
||||
if !*initialSyncDone {
|
||||
snapshotIDs = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
|
||||
for {
|
||||
// Check for context completion to gracefully shutdown.
|
||||
select {
|
||||
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
if !*initialSyncDone {
|
||||
for _, m := range msg.GetMapping() {
|
||||
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
s.reconcileSnapshot(ctx, snapshotIDs)
|
||||
snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||
s.portMu.RLock()
|
||||
var stale []*proto.ProxyMapping
|
||||
for svcID, mapping := range s.lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, mapping)
|
||||
}
|
||||
}
|
||||
s.portMu.RUnlock()
|
||||
|
||||
for _, mapping := range stale {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Info("Removing stale mapping absent from snapshot")
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
|
||||
227
proxy/snapshot_reconcile_test.go
Normal file
227
proxy/snapshot_reconcile_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
|
||||
// so we can verify it without triggering removeMapping (which requires full
|
||||
// server wiring). This keeps the test focused on the detection algorithm.
|
||||
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
|
||||
var stale []types.ServiceID
|
||||
for svcID := range lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, svcID)
|
||||
}
|
||||
}
|
||||
return stale
|
||||
}
|
||||
|
||||
// TestStaleDetection_PartialOverlap verifies that only services absent from
|
||||
// the snapshot are flagged as stale.
|
||||
func TestStaleDetection_PartialOverlap(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
"svc-stale-a": {Id: "svc-stale-a"},
|
||||
"svc-stale-b": {Id: "svc-stale-b"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
"svc-3": {}, // new service, not in local
|
||||
}
|
||||
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Len(t, stale, 2)
|
||||
staleSet := make(map[types.ServiceID]struct{})
|
||||
for _, id := range stale {
|
||||
staleSet[id] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
|
||||
}
|
||||
|
||||
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
|
||||
func TestStaleDetection_AllStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
|
||||
assert.Len(t, stale, 2)
|
||||
}
|
||||
|
||||
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
|
||||
func TestStaleDetection_NoneStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
|
||||
func TestStaleDetection_EmptyLocal(t *testing.T) {
|
||||
stale := collectStaleIDs(
|
||||
map[types.ServiceID]*proto.ProxyMapping{},
|
||||
map[types.ServiceID]struct{}{"svc-1": {}},
|
||||
)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
|
||||
// local mappings are present in the snapshot (removeMapping is never called).
|
||||
func TestReconcileSnapshot_NoStale(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
|
||||
|
||||
snapshotIDs := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
// This should not panic — no stale entries means removeMapping is never called.
|
||||
s.reconcileSnapshot(context.Background(), snapshotIDs)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
|
||||
// no local mappings.
|
||||
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
|
||||
assert.Empty(t, s.lastMappings)
|
||||
}
|
||||
|
||||
// --- handleMappingStream tests for batched snapshot ID accumulation ---
|
||||
|
||||
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
|
||||
// marked done only after the final InitialSyncComplete message, even when
|
||||
// the snapshot arrives in multiple batches.
|
||||
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
|
||||
checker := health.NewChecker(nil, nil)
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // batch 1: no sync-complete
|
||||
{}, // batch 2: no sync-complete
|
||||
{InitialSyncComplete: true}, // batch 3: sync done
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "sync should be marked done after final batch")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
|
||||
// arriving after InitialSyncComplete do not trigger a second reconciliation.
|
||||
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// Simulate state left over from a previous sync.
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // post-sync empty message — must not reconcile
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := true // sync already completed in a previous stream
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2,
|
||||
"post-sync messages must not trigger reconciliation — all entries should survive")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
|
||||
// stream closes before sync completes, no reconciliation occurs.
|
||||
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{} // no messages → immediate EOF
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
|
||||
}
|
||||
|
||||
// mockErrRecvStream returns an error on the second Recv to verify
|
||||
// handleMappingStream returns without completing sync.
|
||||
type mockErrRecvStream struct {
|
||||
mockMappingStream
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &proto.GetMappingUpdateResponse{}, nil
|
||||
}
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, syncDone)
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
|
||||
}
|
||||
Reference in New Issue
Block a user