Mirror of https://github.com/netbirdio/netbird.git, synced 2026-05-11 03:09:55 +00:00
Compare commits: windows-dn...ssh-config (6 commits)

ad8765568b
b19b7464ea
cfb1b3fe31
3c28d29725
6e22e8a6fb
9db7bec233
@@ -224,18 +224,31 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
 func (m *Manager) writeSSHConfig(sshConfig string) error {
 	sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
-	sshConfigPathTmp := sshConfigPath + ".tmp"
 
 	if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
 		return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
 	}
 
-	if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
-		return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
+	tmp, err := os.CreateTemp(m.sshConfigDir, m.sshConfigFile+".*.tmp")
+	if err != nil {
+		return fmt.Errorf("create temp SSH config: %w", err)
+	}
+	tmpPath := tmp.Name()
+	defer func() {
+		if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) {
+			log.Debugf("remove temp SSH config %s: %v", tmpPath, err)
+		}
+	}()
+	if err := tmp.Close(); err != nil {
+		return fmt.Errorf("close temp SSH config %s: %w", tmpPath, err)
+	}
+
+	if err := writeFileWithTimeout(tmpPath, []byte(sshConfig), 0644); err != nil {
+		return fmt.Errorf("write SSH config file %s: %w", tmpPath, err)
 	}
 
-	if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
-		return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
+	if err := os.Rename(tmpPath, sshConfigPath); err != nil {
+		return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err)
 	}
 
 	log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
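The rewritten writeSSHConfig creates the temp file with os.CreateTemp inside the destination directory and renames it into place, so concurrent writers get unique temp names and readers never observe a half-written config. A minimal, self-contained sketch of the same atomic-write pattern, using only the standard library (the diff's writeFileWithTimeout helper is replaced here by a plain write, purely for illustration):

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// atomicWrite writes data to path by creating a unique temp file in the
// same directory and renaming it into place. Rename is atomic on POSIX
// filesystems, so readers see either the old file or the new one.
func atomicWrite(path string, data []byte, perm os.FileMode) error {
	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".*.tmp")
	if err != nil {
		return fmt.Errorf("create temp file: %w", err)
	}
	tmpPath := tmp.Name()
	defer os.Remove(tmpPath) // best-effort cleanup; a no-op after a successful rename

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return fmt.Errorf("write temp file: %w", err)
	}
	if err := tmp.Chmod(perm); err != nil {
		tmp.Close()
		return fmt.Errorf("chmod temp file: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("close temp file: %w", err)
	}
	return os.Rename(tmpPath, path)
}

func main() {
	if err := atomicWrite(filepath.Join(os.TempDir(), "example.conf"), []byte("Host *\n"), 0644); err != nil {
		fmt.Println("write failed:", err)
	}
}

Creating the temp file in the same directory as the destination matters: os.Rename is only atomic within a single filesystem.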
@@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
 }
 
 // UpdateConnector updates an existing connector in Dex storage.
-// It merges incoming updates with existing values to prevent data loss on partial updates.
+// It overlays user-mutable config fields (issuer, clientID, clientSecret,
+// redirectURI) onto the stored connector config, and updates the connector name
+// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
+// partial updates preserve create-time defaults such as scopes, claimMapping,
+// and userIDKey.
 func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
 	if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
-		oldCfg, err := p.parseStorageConnector(old)
-		if err != nil {
-			return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
+		if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
+			return storage.Connector{}, errors.New("connector type change not allowed")
 		}
 
-		mergeConnectorConfig(cfg, oldCfg)
-
-		storageConn, err := p.buildStorageConnector(cfg)
+		configData, err := overlayConnectorConfig(old.Config, cfg)
 		if err != nil {
-			return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
+			return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
 		}
-		return storageConn, nil
+
+		name := cfg.Name
+		if name == "" {
+			name = old.Name
+		}
+
+		return storage.Connector{
+			ID:     cfg.ID,
+			Type:   old.Type,
+			Name:   name,
+			Config: configData,
+		}, nil
 	}); err != nil {
 		return fmt.Errorf("failed to update connector: %w", err)
 	}
@@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
 	return nil
 }
 
-// mergeConnectorConfig preserves existing values for empty fields in the update.
-func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
-	if cfg.ClientSecret == "" {
-		cfg.ClientSecret = oldCfg.ClientSecret
+// overlayConnectorConfig writes only the user-mutable fields onto the existing
+// stored config, preserving every other field (scopes, claimMapping, userIDKey,
+// insecure flags, etc.). Empty fields on cfg leave the existing value alone.
+func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
+	var m map[string]any
+	if err := decodeConnectorConfig(oldConfig, &m); err != nil {
+		return nil, err
 	}
-	if cfg.RedirectURI == "" {
-		cfg.RedirectURI = oldCfg.RedirectURI
+	if cfg.Issuer != "" {
+		m["issuer"] = cfg.Issuer
 	}
-	if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
-		cfg.Issuer = oldCfg.Issuer
+	if cfg.ClientID != "" {
+		m["clientID"] = cfg.ClientID
 	}
-	if cfg.ClientID == "" {
-		cfg.ClientID = oldCfg.ClientID
+	if cfg.ClientSecret != "" {
+		m["clientSecret"] = cfg.ClientSecret
 	}
-	if cfg.Name == "" {
-		cfg.Name = oldCfg.Name
+	if cfg.RedirectURI != "" {
+		m["redirectURI"] = cfg.RedirectURI
 	}
+	return encodeConnectorConfig(m)
 }
 
 // DeleteConnector removes a connector from Dex storage.
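overlayConnectorConfig decodes the stored config into a generic map and overwrites only the four user-mutable keys, so unknown or provider-specific keys pass through untouched. Assuming decodeConnectorConfig and encodeConnectorConfig are thin JSON wrappers (an assumption; the diff does not show them), the overlay semantics can be sketched with encoding/json alone:

package main

import (
	"encoding/json"
	"fmt"
)

// overlay applies only the non-empty updates onto the stored JSON config,
// leaving every other key (scopes, claimMapping, userIDKey, ...) untouched.
func overlay(stored []byte, updates map[string]string) ([]byte, error) {
	var m map[string]any
	if err := json.Unmarshal(stored, &m); err != nil {
		return nil, err
	}
	for k, v := range updates {
		if v != "" {
			m[k] = v
		}
	}
	return json.Marshal(m)
}

func main() {
	stored := []byte(`{"clientID":"client-id","clientSecret":"old","userIDKey":"oid"}`)
	// Rotate only the client secret; clientID and userIDKey must survive.
	out, err := overlay(stored, map[string]string{"clientSecret": "new"})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"clientID":"client-id","clientSecret":"new","userIDKey":"oid"}
}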
@@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
 		oidcConfig["getUserInfo"] = true
 	case "entra":
 		oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
+		// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
+		// Entra issues sub as a per-app pairwise identifier that does not match
+		// the stable Object ID.
+		oidcConfig["userIDKey"] = "oid"
 	case "okta":
 		oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
 	case "pocketid":
idp/dex/connector_test.go (new file, 205 lines)
@@ -0,0 +1,205 @@

package dex

import (
	"context"
	"encoding/json"
	"log/slog"
	"os"
	"path/filepath"
	"testing"

	"github.com/dexidp/dex/storage"
	"github.com/dexidp/dex/storage/sql"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func newTestProvider(t *testing.T) (*Provider, func()) {
	t.Helper()
	tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
	require.NoError(t, err)

	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
	require.NoError(t, err)

	return &Provider{storage: s, logger: logger}, func() {
		_ = s.Close()
		_ = os.RemoveAll(tmpDir)
	}
}

func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
	cfg := &ConnectorConfig{
		ID:           "entra-test",
		Name:         "Entra",
		Type:         "entra",
		Issuer:       "https://login.microsoftonline.com/tid/v2.0",
		ClientID:     "client-id",
		ClientSecret: "client-secret",
	}
	data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
	require.NoError(t, err)

	var m map[string]any
	require.NoError(t, json.Unmarshal(data, &m))

	assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
	assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
}

func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
	// ensures the Entra userIDKey override does not leak into other OIDC providers,
	// which already use a stable sub claim.
	for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
		t.Run(typ, func(t *testing.T) {
			data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
			require.NoError(t, err)
			var m map[string]any
			require.NoError(t, json.Unmarshal(data, &m))
			_, ok := m["userIDKey"]
			assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
		})
	}
}

func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
	ctx := context.Background()
	p, cleanup := newTestProvider(t)
	defer cleanup()

	created, err := p.CreateConnector(ctx, &ConnectorConfig{
		ID:           "entra-test",
		Name:         "Entra",
		Type:         "entra",
		Issuer:       "https://login.microsoftonline.com/tid/v2.0",
		ClientID:     "client-id",
		ClientSecret: "old-secret",
		RedirectURI:  "https://example.com/oauth2/callback",
	})
	require.NoError(t, err)
	require.Equal(t, "entra-test", created.ID)

	// Rotate only the client secret.
	err = p.UpdateConnector(ctx, &ConnectorConfig{
		ID:           "entra-test",
		Type:         "entra",
		ClientSecret: "new-secret",
	})
	require.NoError(t, err)

	conn, err := p.storage.GetConnector(ctx, "entra-test")
	require.NoError(t, err)
	var m map[string]any
	require.NoError(t, json.Unmarshal(conn.Config, &m))

	assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
	assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
	assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
	assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
	assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
}

func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
	ctx := context.Background()
	p, cleanup := newTestProvider(t)
	defer cleanup()

	// Seed a connector directly into storage without userIDKey
	preFixConfig, err := json.Marshal(map[string]any{
		"issuer":       "https://login.microsoftonline.com/tid/v2.0",
		"clientID":     "client-id",
		"clientSecret": "old-secret",
		"redirectURI":  "https://example.com/oauth2/callback",
		"scopes":       []string{"openid", "profile", "email"},
		"claimMapping": map[string]string{"email": "preferred_username"},
	})
	require.NoError(t, err)

	require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
		ID:     "entra-prefix",
		Type:   "oidc",
		Name:   "Entra",
		Config: preFixConfig,
	}))

	// Rotate client secret via UpdateConnector.
	err = p.UpdateConnector(ctx, &ConnectorConfig{
		ID:           "entra-prefix",
		Type:         "entra",
		ClientSecret: "new-secret",
	})
	require.NoError(t, err)

	conn, err := p.storage.GetConnector(ctx, "entra-prefix")
	require.NoError(t, err)
	var m map[string]any
	require.NoError(t, json.Unmarshal(conn.Config, &m))

	assert.Equal(t, "new-secret", m["clientSecret"])
	_, has := m["userIDKey"]
	assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
}

func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
	ctx := context.Background()
	p, cleanup := newTestProvider(t)
	defer cleanup()

	_, err := p.CreateConnector(ctx, &ConnectorConfig{
		ID:           "entra-test",
		Name:         "Entra",
		Type:         "entra",
		Issuer:       "https://login.microsoftonline.com/tid/v2.0",
		ClientID:     "client-id",
		ClientSecret: "secret",
		RedirectURI:  "https://example.com/oauth2/callback",
	})
	require.NoError(t, err)

	// Attempt to switch the connector to okta.
	err = p.UpdateConnector(ctx, &ConnectorConfig{
		ID:   "entra-test",
		Type: "okta",
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "connector type change not allowed")

	// stored connector type/config unchanged after the rejected update.
	conn, err := p.storage.GetConnector(ctx, "entra-test")
	require.NoError(t, err)
	assert.Equal(t, "oidc", conn.Type)
	var m map[string]any
	require.NoError(t, json.Unmarshal(conn.Config, &m))
	assert.Equal(t, "oid", m["userIDKey"])
}

func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
	ctx := context.Background()
	p, cleanup := newTestProvider(t)
	defer cleanup()

	_, err := p.CreateConnector(ctx, &ConnectorConfig{
		ID:           "entra-test",
		Name:         "Entra",
		Type:         "entra",
		Issuer:       "https://login.microsoftonline.com/old/v2.0",
		ClientID:     "client-id",
		ClientSecret: "secret",
		RedirectURI:  "https://example.com/oauth2/callback",
	})
	require.NoError(t, err)

	err = p.UpdateConnector(ctx, &ConnectorConfig{
		ID:     "entra-test",
		Type:   "entra",
		Issuer: "https://login.microsoftonline.com/new/v2.0",
	})
	require.NoError(t, err)

	conn, err := p.storage.GetConnector(ctx, "entra-test")
	require.NoError(t, err)
	var m map[string]any
	require.NoError(t, json.Unmarshal(conn.Config, &m))
	assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
}
@@ -11,6 +11,8 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
+	"os"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -82,11 +84,40 @@ type ProxyServiceServer struct {
 	// Store for PKCE verifiers
 	pkceVerifierStore *PKCEVerifierStore
 
+	// tokenTTL is the lifetime of one-time tokens generated for proxy
+	// authentication. Defaults to defaultProxyTokenTTL when zero.
+	tokenTTL time.Duration
+
+	// snapshotBatchSize is the number of mappings per gRPC message during
+	// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
+	snapshotBatchSize int
+
 	cancel context.CancelFunc
 }
 
 const pkceVerifierTTL = 10 * time.Minute
+
+const defaultProxyTokenTTL = 5 * time.Minute
+
+const defaultSnapshotBatchSize = 500
+
+func snapshotBatchSizeFromEnv() int {
+	if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
+		if n, err := strconv.Atoi(v); err == nil && n > 0 {
+			return n
+		}
+	}
+	return defaultSnapshotBatchSize
+}
+
+// proxyTokenTTL returns the configured token TTL or the default when unset.
+func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
+	if s.tokenTTL > 0 {
+		return s.tokenTTL
+	}
+	return defaultProxyTokenTTL
+}
+
 // proxyConnection represents a connected proxy
 type proxyConnection struct {
 	proxyID string
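snapshotBatchSizeFromEnv falls back to the 500-mapping default whenever NB_PROXY_SNAPSHOT_BATCH_SIZE is unset, non-numeric, or not positive, and proxyTokenTTL applies the same default-when-zero idiom to the token lifetime. A standalone sketch of the parse-or-default behaviour:

package main

import (
	"fmt"
	"os"
	"strconv"
)

// intFromEnv returns the value of name as a positive integer, or def when
// the variable is unset, malformed, or not positive.
func intFromEnv(name string, def int) int {
	if v := os.Getenv(name); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return def
}

func main() {
	os.Setenv("NB_PROXY_SNAPSHOT_BATCH_SIZE", "250")
	fmt.Println(intFromEnv("NB_PROXY_SNAPSHOT_BATCH_SIZE", 500)) // 250

	os.Setenv("NB_PROXY_SNAPSHOT_BATCH_SIZE", "-1")
	fmt.Println(intFromEnv("NB_PROXY_SNAPSHOT_BATCH_SIZE", 500)) // 500: -1 is rejected
}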
@@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
 		peersManager:      peersManager,
 		usersManager:      usersManager,
 		proxyManager:      proxyMgr,
+		snapshotBatchSize: snapshotBatchSizeFromEnv(),
 		cancel:            cancel,
 	}
 	go s.cleanupStaleProxies(ctx)
@@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
 		cancel: cancel,
 	}
 
-	s.connectedProxies.Store(proxyID, conn)
-	if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
-		log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
-	}
-
 	// Register proxy in database with capabilities
 	var caps *proxy.Capabilities
 	if c := req.GetCapabilities(); c != nil {
@@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
 	proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
 	if err != nil {
 		log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
-		s.connectedProxies.CompareAndDelete(proxyID, conn)
-		if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
-			log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
-		}
+		cancel()
 		return status.Errorf(codes.Internal, "register proxy in database: %v", err)
 	}
 
+	s.connectedProxies.Store(proxyID, conn)
+	if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
+		log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
+	}
+
+	if err := s.sendSnapshot(ctx, conn); err != nil {
+		if s.connectedProxies.CompareAndDelete(proxyID, conn) {
+			if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
+				log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
+			}
+		}
+		cancel()
+		if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
+			log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
+		}
+		return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
+	}
+
+	errChan := make(chan error, 2)
+	go s.sender(conn, errChan)
+
 	log.WithFields(log.Fields{
 		"proxy_id":   proxyID,
 		"session_id": sessionID,
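Note the reordering relative to the block removed above: the connection is stored in connectedProxies and registered with the cluster only after proxyManager.Connect succeeds, and the initial snapshot is sent before the sender goroutine starts. If the snapshot fails, the failure path unwinds everything it set up (CompareAndDelete from the map, cluster unregister, cancel, database Disconnect), so a proxy that never received a snapshot no longer lingers half-registered.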
@@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
 		log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
 	}()
 
-	if err := s.sendSnapshot(ctx, conn); err != nil {
-		return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
-	}
-
-	errChan := make(chan error, 2)
-	go s.sender(conn, errChan)
-
 	go s.heartbeat(connCtx, proxyRecord)
 
 	select {
@@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
 		return err
 	}
 
+	// Send mappings in batches to reduce per-message gRPC overhead while
+	// staying well within the default 4 MB message size limit.
+	for i := 0; i < len(mappings); i += s.snapshotBatchSize {
+		end := i + s.snapshotBatchSize
+		if end > len(mappings) {
+			end = len(mappings)
+		}
+		if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
+			Mapping:             mappings[i:end],
+			InitialSyncComplete: end == len(mappings),
+		}); err != nil {
+			return fmt.Errorf("send snapshot batch: %w", err)
+		}
+	}
+
 	if len(mappings) == 0 {
 		if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
 			InitialSyncComplete: true,
 		}); err != nil {
 			return fmt.Errorf("send snapshot completion: %w", err)
 		}
-		return nil
-	}
-
-	for i, m := range mappings {
-		if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
-			Mapping:             []*proto.ProxyMapping{m},
-			InitialSyncComplete: i == len(mappings)-1,
-		}); err != nil {
-			return fmt.Errorf("send proxy mapping: %w", err)
-		}
 	}
 
 	return nil
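The loop sends ceil(len(mappings)/snapshotBatchSize) messages and marks only the final one InitialSyncComplete; the empty-snapshot branch still sends a single completion message so the proxy knows the sync finished. A self-contained sketch of the slicing arithmetic:

package main

import "fmt"

// batches splits items into consecutive chunks of at most size elements.
func batches(items []string, size int) [][]string {
	var chunks [][]string
	for i := 0; i < len(items); i += size {
		end := i + size
		if end > len(items) {
			end = len(items)
		}
		chunks = append(chunks, items[i:end])
	}
	return chunks
}

func main() {
	items := []string{"a", "b", "c", "d", "e", "f", "g"} // 7 items
	chunks := batches(items, 3)
	for i, c := range chunks {
		// Only the last chunk would carry InitialSyncComplete: true.
		fmt.Printf("batch %d: %v, syncComplete=%v\n", i, c, i == len(chunks)-1)
	}
	// batch 0: [a b c], syncComplete=false
	// batch 1: [d e f], syncComplete=false
	// batch 2: [g], syncComplete=true
}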
@@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
 			continue
 		}
 
-		token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
+		token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
 		if err != nil {
-			log.WithFields(log.Fields{
-				"service": service.Name,
-				"account": service.AccountID,
-			}).WithError(err).Error("failed to generate auth token for snapshot")
-			continue
+			return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
 		}
 
 		m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
@@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
 		conn := value.(*proxyConnection)
 		resp := s.perProxyMessage(update, conn.proxyID)
 		if resp == nil {
+			log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
+			conn.cancel()
 			return true
 		}
 		select {
 		case conn.sendChan <- resp:
 			log.Debugf("Sent service update to proxy server %s", conn.proxyID)
 		default:
-			log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
+			log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
+			conn.cancel()
 		}
 		return true
 	})
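The change in the default branch is deliberate: previously a full send channel silently dropped the update, leaving that proxy with stale mappings until its next reconnect, and a failed token generation merely skipped the proxy. Cancelling the connection instead forces a reconnect, after which the proxy receives a fresh batched snapshot and, together with the reconciliation logic further down, converges back to correct state.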
@@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
 		}
 		msg := s.perProxyMessage(updateResponse, proxyID)
 		if msg == nil {
+			log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
+			conn.cancel()
 			continue
 		}
 		select {
 		case conn.sendChan <- msg:
 			log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
 		default:
-			log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
+			log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
+			conn.cancel()
 		}
 	}
 }
@@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
 // perProxyMessage returns a copy of update with a fresh one-time token for
 // create/update operations. For delete operations the original mapping is
 // used unchanged because proxies do not need to authenticate for removal.
-// Returns nil if token generation fails (the proxy should be skipped).
+// Returns nil if token generation fails; the caller must disconnect the
+// proxy so it can resync via a fresh snapshot on reconnect.
 func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
 	resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
 	for _, mapping := range update.Mapping {
@@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
 			continue
 		}
 
-		token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
+		token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
 		if err != nil {
 			log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
 			return nil
management/internals/shared/grpc/proxy_snapshot_test.go (new file, 174 lines)
@@ -0,0 +1,174 @@

package grpc

import (
	"context"
	"fmt"
	"testing"

	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"

	rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
	"github.com/netbirdio/netbird/shared/management/proto"
)

// recordingStream captures all messages sent via Send so tests can inspect
// batching behaviour without a real gRPC transport.
type recordingStream struct {
	grpc.ServerStream
	messages []*proto.GetMappingUpdateResponse
}

func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
	s.messages = append(s.messages, m)
	return nil
}

func (s *recordingStream) Context() context.Context    { return context.Background() }
func (s *recordingStream) SetHeader(metadata.MD) error  { return nil }
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
func (s *recordingStream) SetTrailer(metadata.MD)       {}
func (s *recordingStream) SendMsg(any) error            { return nil }
func (s *recordingStream) RecvMsg(any) error            { return nil }

// makeServices creates n enabled services assigned to the given cluster.
func makeServices(n int, cluster string) []*rpservice.Service {
	services := make([]*rpservice.Service, n)
	for i := range n {
		services[i] = &rpservice.Service{
			ID:           fmt.Sprintf("svc-%d", i),
			AccountID:    "acct-1",
			Name:         fmt.Sprintf("svc-%d", i),
			Domain:       fmt.Sprintf("svc-%d.example.com", i),
			ProxyCluster: cluster,
			Enabled:      true,
			Targets: []*rpservice.Target{
				{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
			},
		}
	}
	return services
}

func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
	t.Helper()
	s := &ProxyServiceServer{
		tokenStore:        NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
		snapshotBatchSize: batchSize,
	}
	s.SetProxyController(newTestProxyController())
	return s
}

func TestSendSnapshot_BatchesMappings(t *testing.T) {
	const cluster = "cluster.example.com"
	const batchSize = 3
	const totalServices = 7 // 3 + 3 + 1

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	s := newSnapshotTestServer(t, batchSize)
	s.serviceManager = mgr

	stream := &recordingStream{}
	conn := &proxyConnection{
		proxyID: "proxy-a",
		address: cluster,
		stream:  stream,
	}

	err := s.sendSnapshot(context.Background(), conn)
	require.NoError(t, err)

	// Expect ceil(7/3) = 3 messages
	require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")

	assert.Len(t, stream.messages[0].Mapping, 3)
	assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")

	assert.Len(t, stream.messages[1].Mapping, 3)
	assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")

	assert.Len(t, stream.messages[2].Mapping, 1)
	assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")

	// Verify all service IDs are present exactly once
	seen := make(map[string]bool)
	for _, msg := range stream.messages {
		for _, m := range msg.Mapping {
			assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
			seen[m.Id] = true
		}
	}
	assert.Len(t, seen, totalServices)
}

func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
	const cluster = "cluster.example.com"
	const batchSize = 3
	const totalServices = 6 // exactly 2 batches

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	s := newSnapshotTestServer(t, batchSize)
	s.serviceManager = mgr

	stream := &recordingStream{}
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}

	require.NoError(t, s.sendSnapshot(context.Background(), conn))
	require.Len(t, stream.messages, 2)

	assert.Len(t, stream.messages[0].Mapping, 3)
	assert.False(t, stream.messages[0].InitialSyncComplete)

	assert.Len(t, stream.messages[1].Mapping, 3)
	assert.True(t, stream.messages[1].InitialSyncComplete)
}

func TestSendSnapshot_SingleBatch(t *testing.T) {
	const cluster = "cluster.example.com"
	const batchSize = 100
	const totalServices = 5

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	s := newSnapshotTestServer(t, batchSize)
	s.serviceManager = mgr

	stream := &recordingStream{}
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}

	require.NoError(t, s.sendSnapshot(context.Background(), conn))
	require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
	assert.Len(t, stream.messages[0].Mapping, totalServices)
	assert.True(t, stream.messages[0].InitialSyncComplete)
}

func TestSendSnapshot_EmptySnapshot(t *testing.T) {
	const cluster = "cluster.example.com"

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)

	s := newSnapshotTestServer(t, 500)
	s.serviceManager = mgr

	stream := &recordingStream{}
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}

	require.NoError(t, s.sendSnapshot(context.Background(), conn))
	require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
	assert.Empty(t, stream.messages[0].Mapping)
	assert.True(t, stream.messages[0].InitialSyncComplete)
}
@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
 // registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
 func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
 	ch := make(chan *proto.GetMappingUpdateResponse, 10)
+	ctx, cancel := context.WithCancel(context.Background())
 	conn := &proxyConnection{
 		proxyID:      proxyID,
 		address:      clusterAddr,
 		capabilities: caps,
 		sendChan:     ch,
+		ctx:          ctx,
+		cancel:       cancel,
 	}
 	s.connectedProxies.Store(proxyID, conn)
@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
 	_, plainToken, err := GenerateInviteToken()
 	require.NoError(t, err)
 
-	// Modify one character in the secret part
-	modifiedToken := plainToken[:5] + "X" + plainToken[6:]
+	replacement := "X"
+	if plainToken[5] == 'X' {
+		replacement = "Y"
+	}
+	modifiedToken := plainToken[:5] + replacement + plainToken[6:]
 	err = ValidateInviteToken(modifiedToken)
 	require.Error(t, err)
 }
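This fixes a rare flake: whenever the sixth character of the generated token happened to already be 'X', the "modified" token was byte-for-byte identical to the original, so ValidateInviteToken succeeded and require.Error failed. Choosing 'Y' in exactly that case guarantees the token really is modified.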
@@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
 	})
 	require.NoError(t, err)
 
-	// Receive all mappings from the snapshot - server sends each mapping individually
 	mappingsByID := make(map[string]*proto.ProxyMapping)
-	for i := 0; i < 2; i++ {
+	for {
 		msg, err := stream.Recv()
 		require.NoError(t, err)
 		for _, m := range msg.GetMapping() {
 			mappingsByID[m.GetId()] = m
 		}
+		if msg.GetInitialSyncComplete() {
+			break
+		}
 	}
 
 	// Should receive 2 mappings total
@@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
 	})
 	require.NoError(t, err)
 
-	// Receive all mappings - server sends each mapping individually
 	mappings := make([]*proto.ProxyMapping, 0)
-	for i := 0; i < 2; i++ {
+	for {
 		msg, err := stream.Recv()
 		require.NoError(t, err)
 		mappings = append(mappings, msg.GetMapping()...)
+		if msg.GetInitialSyncComplete() {
+			break
+		}
 	}
 
 	// Should receive the 2 mappings matching the cluster
@@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
 	clusterAddress := "test.proxy.io"
 	proxyID := "test-proxy-reconnect"
 
-	// Helper to receive all mappings from a stream
-	receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
+	receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
 		var mappings []*proto.ProxyMapping
-		for i := 0; i < count; i++ {
+		for {
 			msg, err := stream.Recv()
 			require.NoError(t, err)
 			mappings = append(mappings, msg.GetMapping()...)
+			if msg.GetInitialSyncComplete() {
+				break
+			}
 		}
 		return mappings
 	}
@@ -460,7 +466,7 @@
 	})
 	require.NoError(t, err)
 
-	firstMappings := receiveMappings(stream1, 2)
+	firstMappings := receiveMappings(stream1)
 	cancel1()
 
 	time.Sleep(100 * time.Millisecond)
@@ -476,7 +482,7 @@
 	})
 	require.NoError(t, err)
 
-	secondMappings := receiveMappings(stream2, 2)
+	secondMappings := receiveMappings(stream2)
 
 	// Should receive the same mappings
 	assert.Equal(t, len(firstMappings), len(secondMappings),
@@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
 		}
 	}
 
-	// Helper to receive and apply all mappings
 	receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
-		for i := 0; i < 2; i++ {
+		for {
 			msg, err := stream.Recv()
 			require.NoError(t, err)
 			applyMappings(msg.GetMapping())
+			if msg.GetInitialSyncComplete() {
+				break
+			}
 		}
 	}
 
@@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
 	})
 	require.NoError(t, err)
 
-	// Receive all mappings - server sends each mapping individually
 	count := 0
-	for i := 0; i < 2; i++ {
+	for {
 		msg, err := stream.Recv()
 		require.NoError(t, err)
 		count += len(msg.GetMapping())
+		if msg.GetInitialSyncComplete() {
+			break
+		}
 	}
 
 	mu.Lock()
@@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
 	})
 	require.NoError(t, err)
 
-	for i := 0; i < 2; i++ {
-		_, err := stream1.Recv()
+	for {
+		msg, err := stream1.Recv()
 		require.NoError(t, err)
+		if msg.GetInitialSyncComplete() {
+			break
+		}
 	}
 
 	require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
@@ -699,9 +712,12 @@
 	})
 	require.NoError(t, err)
 
-	for i := 0; i < 2; i++ {
-		_, err := stream2.Recv()
+	for {
+		msg, err := stream2.Recv()
 		require.NoError(t, err)
+		if msg.GetInitialSyncComplete() {
+			break
+		}
 	}
 
 	cancel1()
@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
 	operation := func() error {
 		s.Logger.Debug("connecting to management mapping stream")
 
+		initialSyncDone = false
+
 		if s.healthChecker != nil {
 			s.healthChecker.SetManagementConnected(false)
 		}
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
 		return ctx.Err()
 	}
 
+	var snapshotIDs map[types.ServiceID]struct{}
+	if !*initialSyncDone {
+		snapshotIDs = make(map[types.ServiceID]struct{})
+	}
+
 	for {
 		// Check for context completion to gracefully shutdown.
 		select {
@@ -1020,17 +1027,45 @@
 			s.processMappings(ctx, msg.GetMapping())
 			s.Logger.Debug("Processing mapping update completed")
 
-			if !*initialSyncDone && msg.GetInitialSyncComplete() {
-				if s.healthChecker != nil {
-					s.healthChecker.SetInitialSyncComplete()
+			if !*initialSyncDone {
+				for _, m := range msg.GetMapping() {
+					snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
+				}
+				if msg.GetInitialSyncComplete() {
+					s.reconcileSnapshot(ctx, snapshotIDs)
+					snapshotIDs = nil
+					if s.healthChecker != nil {
+						s.healthChecker.SetInitialSyncComplete()
+					}
+					*initialSyncDone = true
+					s.Logger.Info("Initial mapping sync complete")
 				}
-				*initialSyncDone = true
-				s.Logger.Info("Initial mapping sync complete")
 			}
 		}
 	}
 }
 
+// reconcileSnapshot removes local mappings that are absent from the snapshot.
+// This ensures services deleted while the proxy was disconnected get cleaned up.
+func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
+	s.portMu.RLock()
+	var stale []*proto.ProxyMapping
+	for svcID, mapping := range s.lastMappings {
+		if _, ok := snapshotIDs[svcID]; !ok {
+			stale = append(stale, mapping)
+		}
+	}
+	s.portMu.RUnlock()
+
+	for _, mapping := range stale {
+		s.Logger.WithFields(log.Fields{
+			"service_id": mapping.GetId(),
+			"domain":     mapping.GetDomain(),
+		}).Info("Removing stale mapping absent from snapshot")
+		s.removeMapping(ctx, mapping)
+	}
+}
+
 func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
 	for _, mapping := range mappings {
 		s.Logger.WithFields(log.Fields{
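reconcileSnapshot runs once, at the end of the first complete snapshot after a (re)connect: service IDs are accumulated across snapshot batches, and any locally known mapping absent from the accumulated set is removed, cleaning up services that were deleted while the proxy was offline. A toy sketch of the accumulate-then-diff flow (string IDs stand in for the real types.ServiceID and proto.ProxyMapping):

package main

import "fmt"

// message models one snapshot batch from the management stream.
type message struct {
	ids  []string
	done bool // InitialSyncComplete on the final batch
}

func main() {
	// Mappings the proxy still holds from before it disconnected.
	local := map[string]bool{"svc-1": true, "svc-stale": true}

	// Snapshot arrives in batches; only the last message carries done=true.
	stream := []message{
		{ids: []string{"svc-1"}},
		{ids: []string{"svc-2"}, done: true},
	}

	snapshot := make(map[string]bool)
	for _, msg := range stream {
		for _, id := range msg.ids {
			snapshot[id] = true
		}
		if msg.done {
			for id := range local {
				if !snapshot[id] {
					fmt.Println("removing stale mapping:", id) // svc-stale
					delete(local, id)
				}
			}
			break
		}
	}
}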
proxy/snapshot_reconcile_test.go (new file, 227 lines)
@@ -0,0 +1,227 @@

package proxy

import (
	"context"
	"io"
	"testing"

	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/netbirdio/netbird/proxy/internal/health"
	"github.com/netbirdio/netbird/proxy/internal/types"
	"github.com/netbirdio/netbird/shared/management/proto"
)

// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
// so we can verify it without triggering removeMapping (which requires full
// server wiring). This keeps the test focused on the detection algorithm.
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
	var stale []types.ServiceID
	for svcID := range lastMappings {
		if _, ok := snapshotIDs[svcID]; !ok {
			stale = append(stale, svcID)
		}
	}
	return stale
}

// TestStaleDetection_PartialOverlap verifies that only services absent from
// the snapshot are flagged as stale.
func TestStaleDetection_PartialOverlap(t *testing.T) {
	local := map[types.ServiceID]*proto.ProxyMapping{
		"svc-1":       {Id: "svc-1"},
		"svc-2":       {Id: "svc-2"},
		"svc-stale-a": {Id: "svc-stale-a"},
		"svc-stale-b": {Id: "svc-stale-b"},
	}
	snapshot := map[types.ServiceID]struct{}{
		"svc-1": {},
		"svc-2": {},
		"svc-3": {}, // new service, not in local
	}

	stale := collectStaleIDs(local, snapshot)
	assert.Len(t, stale, 2)
	staleSet := make(map[types.ServiceID]struct{})
	for _, id := range stale {
		staleSet[id] = struct{}{}
	}
	assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
	assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
}

// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
func TestStaleDetection_AllStale(t *testing.T) {
	local := map[types.ServiceID]*proto.ProxyMapping{
		"svc-1": {Id: "svc-1"},
		"svc-2": {Id: "svc-2"},
	}
	stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
	assert.Len(t, stale, 2)
}

// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
func TestStaleDetection_NoneStale(t *testing.T) {
	local := map[types.ServiceID]*proto.ProxyMapping{
		"svc-1": {Id: "svc-1"},
		"svc-2": {Id: "svc-2"},
	}
	snapshot := map[types.ServiceID]struct{}{
		"svc-1": {},
		"svc-2": {},
	}
	stale := collectStaleIDs(local, snapshot)
	assert.Empty(t, stale)
}

// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
func TestStaleDetection_EmptyLocal(t *testing.T) {
	stale := collectStaleIDs(
		map[types.ServiceID]*proto.ProxyMapping{},
		map[types.ServiceID]struct{}{"svc-1": {}},
	)
	assert.Empty(t, stale)
}

// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
// local mappings are present in the snapshot (removeMapping is never called).
func TestReconcileSnapshot_NoStale(t *testing.T) {
	s := &Server{
		Logger:       log.StandardLogger(),
		lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
	}
	s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
	s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}

	snapshotIDs := map[types.ServiceID]struct{}{
		"svc-1": {},
		"svc-2": {},
	}
	// This should not panic — no stale entries means removeMapping is never called.
	s.reconcileSnapshot(context.Background(), snapshotIDs)

	assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
}

// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
// no local mappings.
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
	s := &Server{
		Logger:       log.StandardLogger(),
		lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
	}
	s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
	assert.Empty(t, s.lastMappings)
}

// --- handleMappingStream tests for batched snapshot ID accumulation ---

// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
// marked done only after the final InitialSyncComplete message, even when
// the snapshot arrives in multiple batches.
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
	checker := health.NewChecker(nil, nil)
	s := &Server{
		Logger:        log.StandardLogger(),
		healthChecker: checker,
		routerReady:   closedChan(),
		lastMappings:  make(map[types.ServiceID]*proto.ProxyMapping),
	}

	stream := &mockMappingStream{
		messages: []*proto.GetMappingUpdateResponse{
			{}, // batch 1: no sync-complete
			{}, // batch 2: no sync-complete
			{InitialSyncComplete: true}, // batch 3: sync done
		},
	}

	syncDone := false
	err := s.handleMappingStream(context.Background(), stream, &syncDone)
	assert.NoError(t, err)
	assert.True(t, syncDone, "sync should be marked done after final batch")
}

// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
// arriving after InitialSyncComplete do not trigger a second reconciliation.
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
	s := &Server{
		Logger:       log.StandardLogger(),
		routerReady:  closedChan(),
		lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
	}

	// Simulate state left over from a previous sync.
	s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
	s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}

	stream := &mockMappingStream{
		messages: []*proto.GetMappingUpdateResponse{
			{}, // post-sync empty message — must not reconcile
		},
	}

	syncDone := true // sync already completed in a previous stream
	err := s.handleMappingStream(context.Background(), stream, &syncDone)
	require.NoError(t, err)

	assert.Len(t, s.lastMappings, 2,
		"post-sync messages must not trigger reconciliation — all entries should survive")
}

// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
// stream closes before sync completes, no reconciliation occurs.
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
	s := &Server{
		Logger:       log.StandardLogger(),
		routerReady:  closedChan(),
		lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
	}

	s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}

	stream := &mockMappingStream{} // no messages → immediate EOF

	syncDone := false
	err := s.handleMappingStream(context.Background(), stream, &syncDone)
	assert.NoError(t, err)
	assert.False(t, syncDone, "sync should not be marked done on immediate EOF")

	_, hasStale := s.lastMappings["svc-stale"]
	assert.True(t, hasStale, "stale mapping should remain when sync never completed")
}

// mockErrRecvStream returns an error on the second Recv to verify
// handleMappingStream returns without completing sync.
type mockErrRecvStream struct {
	mockMappingStream
	calls int
}

func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
	m.calls++
	if m.calls == 1 {
		return &proto.GetMappingUpdateResponse{}, nil
	}
	return nil, io.ErrUnexpectedEOF
}

func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
	s := &Server{
		Logger:       log.StandardLogger(),
		routerReady:  closedChan(),
		lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
	}

	s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}

	syncDone := false
	err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
	assert.Error(t, err)
	assert.False(t, syncDone)

	_, hasStale := s.lastMappings["svc-stale"]
	assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
}