mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 18:26:41 +00:00
Compare commits
15 Commits
feature/dy
...
cached-ser
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53deabbdb4 | ||
|
|
ac3fe4343b | ||
|
|
a4ae160993 | ||
|
|
3ac4263257 | ||
|
|
dc86c9655d | ||
|
|
66494d61af | ||
|
|
46446acd30 | ||
|
|
3eb1298cb4 | ||
|
|
93391fc68f | ||
|
|
48c080b861 | ||
|
|
3716838c25 | ||
|
|
5d58000dbd | ||
|
|
8430b06f2a | ||
|
|
5a89e6621b | ||
|
|
3f4ef0031b |
5
.github/workflows/golang-test-linux.yml
vendored
5
.github/workflows/golang-test-linux.yml
vendored
@@ -426,8 +426,11 @@ jobs:
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Generate current sync wire fixtures
|
||||
run: go run ./management/server/testdata/sync_request_wire/generate.go
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
|
||||
@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -570,7 +570,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
|
||||
@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
forceRelay := IsForceRelayed()
|
||||
if !forceRelay {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
if !forceRelay {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.Close()
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
conn.dumpState.RemoteCandidate()
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
||||
//
|
||||
// The result is a tri-state:
|
||||
// - ConnStatusConnected: all available transports are up
|
||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
||||
// - ConnStatusDisconnected: no working transport
|
||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
defer func() {
|
||||
if !connected {
|
||||
if status == guard.ConnStatusDisconnected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
iceWorkerCreated := conn.workerICE != nil
|
||||
|
||||
var iceInProgress bool
|
||||
if iceWorkerCreated {
|
||||
iceInProgress = conn.workerICE.InProgress()
|
||||
}
|
||||
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return evalConnStatus(connStatusInputs{
|
||||
forceRelay: IsForceRelayed(),
|
||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
||||
iceWorkerCreated: iceWorkerCreated,
|
||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
||||
iceInProgress: iceInProgress,
|
||||
})
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
||||
|
||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
||||
if in.forceRelay {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
||||
// relay is the only possible transport.
|
||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
||||
|
||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
||||
|
||||
switch {
|
||||
case iceUp && relayOK:
|
||||
return guard.ConnStatusConnected
|
||||
case relayUsedAndUp:
|
||||
// Relay is up but ICE is down — partially connected.
|
||||
return guard.ConnStatusPartiallyConnected
|
||||
default:
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
}
|
||||
|
||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
||||
if connected {
|
||||
return guard.ConnStatusConnected
|
||||
}
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
|
||||
@@ -13,6 +13,20 @@ const (
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
||||
// without constructing full Worker/Handshaker objects.
|
||||
type connStatusInputs struct {
|
||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
201
client/internal/peer/conn_status_eval_test.go
Normal file
201
client/internal/peer/conn_status_eval_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
)
|
||||
|
||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "force relay, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: false,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE worker not yet created, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: false,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer does not use relay",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: false,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
||||
base := connStatusInputs{
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutator func(*connStatusInputs)
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "ICE connected, relay connected, peer uses relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE connected, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE InProgress only, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up, peer uses relay -> partial",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusPartiallyConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
||||
// falls into default: Disconnected.
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
in := base
|
||||
tc.mutator(&in)
|
||||
if got := evalConnStatus(in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
func IsForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -8,7 +8,19 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
// ConnStatus represents the connection state as seen by the guard.
|
||||
type ConnStatus int
|
||||
|
||||
const (
|
||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
||||
ConnStatusDisconnected ConnStatus = iota
|
||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
||||
ConnStatusPartiallyConnected
|
||||
// ConnStatusConnected means all required connections are established.
|
||||
ConnStatusConnected
|
||||
)
|
||||
|
||||
type connStatusFunc func() ConnStatus
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
log *log.Entry
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
isConnectedOnAllWay connStatusFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan struct{}
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
log: log,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
||||
//
|
||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
||||
// - Connected: no action, the peer is fully reachable.
|
||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
||||
//
|
||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
iceState := &iceRetryState{log: g.log}
|
||||
defer iceState.reset()
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
case <-tickerChannel:
|
||||
switch g.isConnectedOnAllWay() {
|
||||
case ConnStatusConnected:
|
||||
// all good, nothing to do
|
||||
case ConnStatusDisconnected:
|
||||
callback()
|
||||
case ConnStatusPartiallyConnected:
|
||||
if iceState.shouldRetry() {
|
||||
callback()
|
||||
} else {
|
||||
iceState.enterHourlyMode()
|
||||
ticker.Stop()
|
||||
tickerChannel = iceState.hourlyC()
|
||||
}
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
callback()
|
||||
}
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
|
||||
61
client/internal/peer/guard/ice_retry_state.go
Normal file
61
client/internal/peer/guard/ice_retry_state.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
||||
maxICERetries = 3
|
||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
||||
iceRetryInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
||||
type iceRetryState struct {
|
||||
log *log.Entry
|
||||
retries int
|
||||
hourly *time.Ticker
|
||||
}
|
||||
|
||||
func (s *iceRetryState) reset() {
|
||||
s.retries = 0
|
||||
if s.hourly != nil {
|
||||
s.hourly.Stop()
|
||||
s.hourly = nil
|
||||
}
|
||||
}
|
||||
|
||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
||||
func (s *iceRetryState) shouldRetry() bool {
|
||||
if s.hourly != nil {
|
||||
s.log.Debugf("hourly ICE retry attempt")
|
||||
return true
|
||||
}
|
||||
|
||||
s.retries++
|
||||
if s.retries <= maxICERetries {
|
||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
||||
func (s *iceRetryState) enterHourlyMode() {
|
||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
||||
s.hourly = time.NewTicker(iceRetryInterval)
|
||||
}
|
||||
|
||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
||||
if s.hourly == nil {
|
||||
return nil
|
||||
}
|
||||
return s.hourly.C
|
||||
}
|
||||
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTestRetryState() *iceRetryState {
|
||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
||||
}
|
||||
|
||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 0; i < maxICERetries; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
if s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if s.hourlyC() == nil {
|
||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
|
||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
s.enterHourlyMode()
|
||||
|
||||
s.reset()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
||||
}
|
||||
if s.retries != 0 {
|
||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
||||
}
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.reset()
|
||||
s.reset() // second call must not panic or re-stop a nil ticker
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC non-nil after double reset")
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start() {
|
||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
if !disableICEMonitor {
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +44,10 @@ type OfferAnswer struct {
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
|
||||
func (o *OfferAnswer) hasICECredentials() bool {
|
||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
@@ -59,6 +64,10 @@ type Handshaker struct {
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
||||
remoteICESupported atomic.Bool
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
@@ -66,7 +75,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
return &Handshaker{
|
||||
h := &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
// assume remote supports ICE until we learn otherwise from received offers
|
||||
h.remoteICESupported.Store(ice != nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handshaker) RemoteICESupported() bool {
|
||||
return h.remoteICESupported.Load()
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
|
||||
}
|
||||
|
||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
SessionID: &sid,
|
||||
}
|
||||
|
||||
if h.ice != nil && h.RemoteICESupported() {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
||||
answer.SessionID = &sid
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
||||
hasICE := offer.hasICECredentials()
|
||||
prev := h.remoteICESupported.Swap(hasICE)
|
||||
if prev != hasICE {
|
||||
if hasICE {
|
||||
h.log.Infof("remote peer started sending ICE credentials")
|
||||
} else {
|
||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
||||
if h.ice != nil {
|
||||
h.ice.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
var sessionIDBytes []byte
|
||||
if offerAnswer.SessionID != nil {
|
||||
var err error
|
||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
|
||||
@@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -163,7 +163,9 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
|
||||
peerSerialCache := nbgrpc.NewPeerSerialCache(context.Background(), s.CacheStore(), nbgrpc.DefaultPeerSerialCacheTTL)
|
||||
fastPathFlag := nbgrpc.RunFastPathFlagRoutine(context.Background(), s.CacheStore(), nbgrpc.DefaultFastPathFlagInterval, nbgrpc.DefaultFastPathFlagKey)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), peerSerialCache, fastPathFlag)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
131
management/internals/shared/grpc/fast_path_flag.go
Normal file
131
management/internals/shared/grpc/fast_path_flag.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultFastPathFlagInterval is the default poll interval for the Sync
|
||||
// fast-path feature flag. Kept lower than the log-level overrider because
|
||||
// operators will want the toggle to propagate quickly during rollout.
|
||||
DefaultFastPathFlagInterval = 1 * time.Minute
|
||||
|
||||
// DefaultFastPathFlagKey is the cache key polled by RunFastPathFlagRoutine
|
||||
// when the caller does not provide an override.
|
||||
DefaultFastPathFlagKey = "peerSyncFastPath"
|
||||
)
|
||||
|
||||
// FastPathFlag exposes the current on/off state of the Sync fast path. The
|
||||
// zero value and a nil receiver both report disabled, so callers can always
|
||||
// treat the flag as a non-nil gate without an additional nil check.
|
||||
type FastPathFlag struct {
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewFastPathFlag returns a FastPathFlag whose state is set to the given
|
||||
// value. Callers that need the runtime toggle should use
|
||||
// RunFastPathFlagRoutine instead; this constructor is meant for tests and
|
||||
// for consumers that want to force the flag on or off.
|
||||
func NewFastPathFlag(enabled bool) *FastPathFlag {
|
||||
f := &FastPathFlag{}
|
||||
f.setEnabled(enabled)
|
||||
return f
|
||||
}
|
||||
|
||||
// Enabled reports whether the Sync fast path is currently enabled for this
|
||||
// replica. A nil receiver reports false so a disabled build or test can pass
|
||||
// a nil flag and skip the fast path entirely.
|
||||
func (f *FastPathFlag) Enabled() bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
return f.enabled.Load()
|
||||
}
|
||||
|
||||
func (f *FastPathFlag) setEnabled(v bool) {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
f.enabled.Store(v)
|
||||
}
|
||||
|
||||
// RunFastPathFlagRoutine starts a background goroutine that polls the shared
|
||||
// cache store for the Sync fast-path feature flag and updates the returned
|
||||
// FastPathFlag accordingly. When cacheStore is nil the routine returns a
|
||||
// handle that stays permanently disabled, so every Sync falls back to the
|
||||
// full network map path.
|
||||
//
|
||||
// The shared store is Redis-backed when NB_CACHE_REDIS_ADDRESS is set (so the
|
||||
// flag is toggled cluster-wide by writing the key in Redis) and falls back to
|
||||
// an in-process gocache otherwise, which is enough for single-replica dev and
|
||||
// test setups.
|
||||
//
|
||||
// The routine fails closed: any store read error (other than a plain "key not
|
||||
// found" miss) disables the flag until Redis confirms it is enabled again.
|
||||
func RunFastPathFlagRoutine(ctx context.Context, cacheStore store.StoreInterface, interval time.Duration, flagKey string) *FastPathFlag {
|
||||
flag := &FastPathFlag{}
|
||||
|
||||
if cacheStore == nil {
|
||||
log.Infof("Shared cache store not provided. Sync fast path disabled")
|
||||
return flag
|
||||
}
|
||||
|
||||
if flagKey == "" {
|
||||
flagKey = DefaultFastPathFlagKey
|
||||
}
|
||||
|
||||
flagCache := cache.New[string](cacheStore)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
refresh := func() {
|
||||
getCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
value, err := flagCache.Get(getCtx, flagKey)
|
||||
if err != nil {
|
||||
var notFound *store.NotFound
|
||||
if !errors.As(err, ¬Found) {
|
||||
log.Errorf("Sync fast-path flag refresh: %v; disabling fast path", err)
|
||||
}
|
||||
flag.setEnabled(false)
|
||||
return
|
||||
}
|
||||
flag.setEnabled(parseFastPathFlag(value))
|
||||
}
|
||||
|
||||
refresh()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Stopping Sync fast-path flag routine")
|
||||
return
|
||||
case <-ticker.C:
|
||||
refresh()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return flag
|
||||
}
|
||||
|
||||
// parseFastPathFlag accepts "1" or "true" (any casing, surrounding whitespace
|
||||
// tolerated) as enabled and treats every other value as disabled.
|
||||
func parseFastPathFlag(value string) bool {
|
||||
v := strings.TrimSpace(value)
|
||||
if v == "1" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(v, "true")
|
||||
}
|
||||
176
management/internals/shared/grpc/fast_path_flag_test.go
Normal file
176
management/internals/shared/grpc/fast_path_flag_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
gocache_store "github.com/eko/gocache/store/go_cache/v4"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseFastPathFlag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want bool
|
||||
}{
|
||||
{"one", "1", true},
|
||||
{"true lowercase", "true", true},
|
||||
{"true uppercase", "TRUE", true},
|
||||
{"true mixed case", "True", true},
|
||||
{"true with whitespace", " true ", true},
|
||||
{"zero", "0", false},
|
||||
{"false", "false", false},
|
||||
{"empty", "", false},
|
||||
{"yes", "yes", false},
|
||||
{"garbage", "garbage", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, parseFastPathFlag(tt.value), "parseFastPathFlag(%q)", tt.value)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFastPathFlag_EnabledDefaultsFalse(t *testing.T) {
|
||||
flag := &FastPathFlag{}
|
||||
assert.False(t, flag.Enabled(), "zero value flag should report disabled")
|
||||
}
|
||||
|
||||
func TestFastPathFlag_NilSafeEnabled(t *testing.T) {
|
||||
var flag *FastPathFlag
|
||||
assert.False(t, flag.Enabled(), "nil flag should report disabled without panicking")
|
||||
}
|
||||
|
||||
func TestFastPathFlag_SetEnabled(t *testing.T) {
|
||||
flag := &FastPathFlag{}
|
||||
flag.setEnabled(true)
|
||||
assert.True(t, flag.Enabled(), "flag should report enabled after setEnabled(true)")
|
||||
flag.setEnabled(false)
|
||||
assert.False(t, flag.Enabled(), "flag should report disabled after setEnabled(false)")
|
||||
}
|
||||
|
||||
func TestNewFastPathFlag(t *testing.T) {
|
||||
assert.True(t, NewFastPathFlag(true).Enabled(), "NewFastPathFlag(true) should report enabled")
|
||||
assert.False(t, NewFastPathFlag(false).Enabled(), "NewFastPathFlag(false) should report disabled")
|
||||
}
|
||||
|
||||
func TestRunFastPathFlagRoutine_NilStoreStaysDisabled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
flag := RunFastPathFlagRoutine(ctx, nil, 50*time.Millisecond, "peerSyncFastPath")
|
||||
require.NotNil(t, flag, "RunFastPathFlagRoutine should always return a non-nil flag")
|
||||
assert.False(t, flag.Enabled(), "flag should stay disabled when no cache store is provided")
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
assert.False(t, flag.Enabled(), "flag should remain disabled after wait when no cache store is provided")
|
||||
}
|
||||
|
||||
func TestRunFastPathFlagRoutine_ReadsFlagFromStore(t *testing.T) {
|
||||
cacheStore := newFastPathTestStore(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
flag := RunFastPathFlagRoutine(ctx, cacheStore, 50*time.Millisecond, "peerSyncFastPath")
|
||||
require.NotNil(t, flag)
|
||||
assert.False(t, flag.Enabled(), "flag should start disabled when the key is missing")
|
||||
|
||||
require.NoError(t, cacheStore.Set(ctx, "peerSyncFastPath", "1"), "seed flag=1 into shared store")
|
||||
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should flip enabled after the key is set to 1")
|
||||
|
||||
require.NoError(t, cacheStore.Set(ctx, "peerSyncFastPath", "0"), "override flag=0 into shared store")
|
||||
assert.Eventually(t, func() bool {
|
||||
return !flag.Enabled()
|
||||
}, 2*time.Second, 25*time.Millisecond, "flag should flip disabled after the key is set to 0")
|
||||
|
||||
require.NoError(t, cacheStore.Delete(ctx, "peerSyncFastPath"), "remove flag key")
|
||||
assert.Eventually(t, func() bool {
|
||||
return !flag.Enabled()
|
||||
}, 2*time.Second, 25*time.Millisecond, "flag should stay disabled after the key is deleted")
|
||||
|
||||
require.NoError(t, cacheStore.Set(ctx, "peerSyncFastPath", "true"), "enable via string true")
|
||||
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should accept \"true\" as enabled")
|
||||
}
|
||||
|
||||
func TestRunFastPathFlagRoutine_MissingKeyKeepsDisabled(t *testing.T) {
|
||||
cacheStore := newFastPathTestStore(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
flag := RunFastPathFlagRoutine(ctx, cacheStore, 50*time.Millisecond, "peerSyncFastPathAbsent")
|
||||
require.NotNil(t, flag)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
assert.False(t, flag.Enabled(), "flag should stay disabled when the key is missing from the store")
|
||||
}
|
||||
|
||||
func TestRunFastPathFlagRoutine_DefaultKeyUsedWhenEmpty(t *testing.T) {
|
||||
cacheStore := newFastPathTestStore(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
require.NoError(t, cacheStore.Set(ctx, DefaultFastPathFlagKey, "1"), "seed default key")
|
||||
|
||||
flag := RunFastPathFlagRoutine(ctx, cacheStore, 50*time.Millisecond, "")
|
||||
require.NotNil(t, flag)
|
||||
|
||||
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "empty flagKey should fall back to DefaultFastPathFlagKey")
|
||||
}
|
||||
|
||||
func newFastPathTestStore(t *testing.T) store.StoreInterface {
|
||||
t.Helper()
|
||||
return gocache_store.NewGoCache(gocache.New(5*time.Minute, 10*time.Minute))
|
||||
}
|
||||
|
||||
func TestRunFastPathFlagRoutine_FailsClosedOnReadError(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
s := &flakyStore{
|
||||
StoreInterface: newFastPathTestStore(t),
|
||||
}
|
||||
require.NoError(t, s.Set(ctx, "peerSyncFastPath", "1"), "seed flag enabled")
|
||||
|
||||
flag := RunFastPathFlagRoutine(ctx, s, 50*time.Millisecond, "peerSyncFastPath")
|
||||
require.NotNil(t, flag)
|
||||
|
||||
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should flip enabled while store reads succeed")
|
||||
|
||||
s.setGetError(errors.New("simulated transient store failure"))
|
||||
assert.Eventually(t, func() bool {
|
||||
return !flag.Enabled()
|
||||
}, 2*time.Second, 25*time.Millisecond, "flag should flip disabled on store read error (fail-closed)")
|
||||
|
||||
s.setGetError(nil)
|
||||
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should recover once the store read succeeds again")
|
||||
}
|
||||
|
||||
// flakyStore wraps a real store and lets tests inject a transient Get error
|
||||
// without affecting Set/Delete. Used to exercise fail-closed behaviour.
|
||||
type flakyStore struct {
|
||||
store.StoreInterface
|
||||
getErr atomic.Pointer[error]
|
||||
}
|
||||
|
||||
func (f *flakyStore) Get(ctx context.Context, key any) (any, error) {
|
||||
if errPtr := f.getErr.Load(); errPtr != nil && *errPtr != nil {
|
||||
return nil, *errPtr
|
||||
}
|
||||
return f.StoreInterface.Get(ctx, key)
|
||||
}
|
||||
|
||||
func (f *flakyStore) setGetError(err error) {
|
||||
if err == nil {
|
||||
f.getErr.Store(nil)
|
||||
return
|
||||
}
|
||||
f.getErr.Store(&err)
|
||||
}
|
||||
82
management/internals/shared/grpc/peer_serial_cache.go
Normal file
82
management/internals/shared/grpc/peer_serial_cache.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
peerSerialCacheKeyPrefix = "peer-sync:"
|
||||
|
||||
// DefaultPeerSerialCacheTTL bounds how long a cached serial survives. If the
|
||||
// cache write on a full-map send ever drops, entries naturally expire and
|
||||
// the next Sync falls back to the full path, re-priming the cache.
|
||||
DefaultPeerSerialCacheTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// PeerSerialCache records the NetworkMap serial and meta hash last delivered to
|
||||
// each peer on Sync. Lookups are used to skip full network map computation when
|
||||
// the peer already has the latest state. Backed by the shared cache store so
|
||||
// entries survive management replicas sharing a Redis instance.
|
||||
type PeerSerialCache struct {
|
||||
cache *cache.Cache[string]
|
||||
ctx context.Context
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewPeerSerialCache creates a cache wrapper bound to the shared cache store.
|
||||
// The ttl is applied to every Set call; entries older than ttl are treated as
|
||||
// misses so the server eventually converges to delivering a full map even if
|
||||
// an earlier Set was lost.
|
||||
func NewPeerSerialCache(ctx context.Context, cacheStore store.StoreInterface, ttl time.Duration) *PeerSerialCache {
|
||||
return &PeerSerialCache{
|
||||
cache: cache.New[string](cacheStore),
|
||||
ctx: ctx,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the entry previously recorded for this peer and whether a valid
|
||||
// entry was found. A cache miss or any read error is reported as a miss so
|
||||
// callers fall back to the full map path.
|
||||
func (c *PeerSerialCache) Get(pubKey string) (peerSyncEntry, bool) {
|
||||
raw, err := c.cache.Get(c.ctx, peerSerialCacheKeyPrefix+pubKey)
|
||||
if err != nil {
|
||||
return peerSyncEntry{}, false
|
||||
}
|
||||
|
||||
entry := peerSyncEntry{}
|
||||
if err := json.Unmarshal([]byte(raw), &entry); err != nil {
|
||||
log.Debugf("peer serial cache: unmarshal entry for %s: %v", pubKey, err)
|
||||
return peerSyncEntry{}, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
// Set records what the server most recently delivered to this peer. Errors are
|
||||
// logged at debug level so cache outages degrade gracefully into the full map
|
||||
// path on the next Sync rather than failing the current Sync.
|
||||
func (c *PeerSerialCache) Set(pubKey string, entry peerSyncEntry) {
|
||||
payload, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
log.Debugf("peer serial cache: marshal entry for %s: %v", pubKey, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.cache.Set(c.ctx, peerSerialCacheKeyPrefix+pubKey, string(payload), store.WithExpiration(c.ttl)); err != nil {
|
||||
log.Debugf("peer serial cache: set entry for %s: %v", pubKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes any cached entry for this peer. Used on Login so the next
|
||||
// Sync always sees a miss and delivers a full map.
|
||||
func (c *PeerSerialCache) Delete(pubKey string) {
|
||||
if err := c.cache.Delete(c.ctx, peerSerialCacheKeyPrefix+pubKey); err != nil {
|
||||
log.Debugf("peer serial cache: delete entry for %s: %v", pubKey, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package grpc
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldSkipNetworkMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
goOS string
|
||||
hit bool
|
||||
cached peerSyncEntry
|
||||
currentSerial uint64
|
||||
incomingMeta uint64
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "android never skips even on clean cache hit",
|
||||
goOS: "android",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "android uppercase never skips",
|
||||
goOS: "Android",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "cache miss forces full path",
|
||||
goOS: "linux",
|
||||
hit: false,
|
||||
cached: peerSyncEntry{},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "serial mismatch forces full path",
|
||||
goOS: "linux",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 41, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "meta mismatch forces full path",
|
||||
goOS: "linux",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 9,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "clean hit on linux skips",
|
||||
goOS: "linux",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "clean hit on darwin skips",
|
||||
goOS: "darwin",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "clean hit on windows skips",
|
||||
goOS: "windows",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "zero current serial never skips",
|
||||
goOS: "linux",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 0, MetaHash: 7},
|
||||
currentSerial: 0,
|
||||
incomingMeta: 7,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty goos treated as non-android and skips",
|
||||
goOS: "",
|
||||
hit: true,
|
||||
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
|
||||
currentSerial: 42,
|
||||
incomingMeta: 7,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := shouldSkipNetworkMap(tc.goOS, tc.hit, tc.cached, tc.currentSerial, tc.incomingMeta)
|
||||
if got != tc.want {
|
||||
t.Fatalf("shouldSkipNetworkMap(%q, hit=%v, cached=%+v, current=%d, meta=%d) = %v, want %v",
|
||||
tc.goOS, tc.hit, tc.cached, tc.currentSerial, tc.incomingMeta, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
134
management/internals/shared/grpc/peer_serial_cache_test.go
Normal file
134
management/internals/shared/grpc/peer_serial_cache_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
func newTestPeerSerialCache(t *testing.T, ttl, cleanup time.Duration) *PeerSerialCache {
|
||||
t.Helper()
|
||||
s, err := nbcache.NewStore(context.Background(), ttl, cleanup, 100)
|
||||
require.NoError(t, err, "cache store must initialise")
|
||||
return NewPeerSerialCache(context.Background(), s, ttl)
|
||||
}
|
||||
|
||||
func TestPeerSerialCache_GetSetDelete(t *testing.T) {
|
||||
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
|
||||
key := "pubkey-aaa"
|
||||
|
||||
_, hit := c.Get(key)
|
||||
assert.False(t, hit, "empty cache must miss")
|
||||
|
||||
c.Set(key, peerSyncEntry{Serial: 42, MetaHash: 7})
|
||||
|
||||
entry, hit := c.Get(key)
|
||||
require.True(t, hit, "after Set, Get must hit")
|
||||
assert.Equal(t, uint64(42), entry.Serial, "serial roundtrip")
|
||||
assert.Equal(t, uint64(7), entry.MetaHash, "meta hash roundtrip")
|
||||
|
||||
c.Delete(key)
|
||||
_, hit = c.Get(key)
|
||||
assert.False(t, hit, "after Delete, Get must miss")
|
||||
}
|
||||
|
||||
func TestPeerSerialCache_GetMissReturnsZero(t *testing.T) {
|
||||
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
|
||||
|
||||
entry, hit := c.Get("missing")
|
||||
assert.False(t, hit, "miss must report false")
|
||||
assert.Equal(t, peerSyncEntry{}, entry, "miss must return zero value")
|
||||
}
|
||||
|
||||
func TestPeerSerialCache_TTLExpiry(t *testing.T) {
|
||||
c := newTestPeerSerialCache(t, 100*time.Millisecond, 10*time.Millisecond)
|
||||
key := "pubkey-ttl"
|
||||
|
||||
c.Set(key, peerSyncEntry{Serial: 1, MetaHash: 1})
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
_, hit := c.Get(key)
|
||||
assert.False(t, hit, "entry must expire after TTL")
|
||||
}
|
||||
|
||||
func TestPeerSerialCache_OverwriteUpdatesValue(t *testing.T) {
|
||||
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
|
||||
key := "pubkey-overwrite"
|
||||
|
||||
c.Set(key, peerSyncEntry{Serial: 1, MetaHash: 1})
|
||||
c.Set(key, peerSyncEntry{Serial: 99, MetaHash: 123})
|
||||
|
||||
entry, hit := c.Get(key)
|
||||
require.True(t, hit, "overwritten key must still be present")
|
||||
assert.Equal(t, uint64(99), entry.Serial, "overwrite updates serial")
|
||||
assert.Equal(t, uint64(123), entry.MetaHash, "overwrite updates meta hash")
|
||||
}
|
||||
|
||||
func TestPeerSerialCache_IsolatedPerKey(t *testing.T) {
|
||||
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
|
||||
|
||||
c.Set("a", peerSyncEntry{Serial: 1, MetaHash: 1})
|
||||
c.Set("b", peerSyncEntry{Serial: 2, MetaHash: 2})
|
||||
|
||||
a, hitA := c.Get("a")
|
||||
b, hitB := c.Get("b")
|
||||
require.True(t, hitA, "key a must hit")
|
||||
require.True(t, hitB, "key b must hit")
|
||||
assert.Equal(t, uint64(1), a.Serial, "key a serial")
|
||||
assert.Equal(t, uint64(2), b.Serial, "key b serial")
|
||||
|
||||
c.Delete("a")
|
||||
_, hitA = c.Get("a")
|
||||
_, hitB = c.Get("b")
|
||||
assert.False(t, hitA, "deleting a must not affect b")
|
||||
assert.True(t, hitB, "b must remain after a delete")
|
||||
}
|
||||
|
||||
func TestPeerSerialCache_Concurrent(t *testing.T) {
|
||||
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const workers = 50
|
||||
const iterations = 20
|
||||
|
||||
for w := 0; w < workers; w++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
key := "pubkey"
|
||||
for i := 0; i < iterations; i++ {
|
||||
c.Set(key, peerSyncEntry{Serial: uint64(id*iterations + i), MetaHash: uint64(id)})
|
||||
_, _ = c.Get(key)
|
||||
}
|
||||
}(w)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
_, hit := c.Get("pubkey")
|
||||
assert.True(t, hit, "cache must survive concurrent Set/Get without deadlock")
|
||||
}
|
||||
|
||||
func TestPeerSerialCache_Redis(t *testing.T) {
|
||||
if os.Getenv(nbcache.RedisStoreEnvVar) == "" {
|
||||
t.Skipf("set %s to run this test against a real Redis", nbcache.RedisStoreEnvVar)
|
||||
}
|
||||
|
||||
s, err := nbcache.NewStore(context.Background(), time.Minute, 10*time.Second, 10)
|
||||
require.NoError(t, err, "redis store must initialise")
|
||||
c := NewPeerSerialCache(context.Background(), s, time.Minute)
|
||||
|
||||
key := "redis-pubkey"
|
||||
c.Set(key, peerSyncEntry{Serial: 42, MetaHash: 7})
|
||||
entry, hit := c.Get(key)
|
||||
require.True(t, hit, "redis hit expected")
|
||||
assert.Equal(t, uint64(42), entry.Serial)
|
||||
c.Delete(key)
|
||||
}
|
||||
@@ -84,9 +84,21 @@ type Server struct {
|
||||
|
||||
reverseProxyManager rpservice.Manager
|
||||
reverseProxyMu sync.RWMutex
|
||||
|
||||
// peerSerialCache lets Sync skip full network map computation when the peer
|
||||
// already has the latest account serial. A nil cache disables the fast path.
|
||||
peerSerialCache *PeerSerialCache
|
||||
|
||||
// fastPathFlag is the runtime kill switch for the Sync fast path. A nil
|
||||
// flag or a flag reporting disabled forces every Sync through the full
|
||||
// network map path.
|
||||
fastPathFlag *FastPathFlag
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
// NewServer creates a new Management server. peerSerialCache and fastPathFlag
|
||||
// are both optional; when either is nil or the flag reports disabled, the
|
||||
// Sync fast path is disabled and every request runs the full map computation,
|
||||
// matching the pre-cache behaviour.
|
||||
func NewServer(
|
||||
config *nbconfig.Config,
|
||||
accountManager account.Manager,
|
||||
@@ -98,6 +110,8 @@ func NewServer(
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
networkMapController network_map.Controller,
|
||||
oAuthConfigProvider idp.OAuthConfigProvider,
|
||||
peerSerialCache *PeerSerialCache,
|
||||
fastPathFlag *FastPathFlag,
|
||||
) (*Server, error) {
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
@@ -145,6 +159,9 @@ func NewServer(
|
||||
|
||||
syncLim: syncLim,
|
||||
syncLimEnabled: syncLimEnabled,
|
||||
|
||||
peerSerialCache: peerSerialCache,
|
||||
fastPathFlag: fastPathFlag,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -294,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||
log.WithContext(ctx).Debugf("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
|
||||
@@ -305,6 +322,10 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
if took, err := s.tryFastPathSync(ctx, reqStart, syncStart, accountID, peerKey, peerMeta, realIP, metahash, srv, &unlock); took {
|
||||
return err
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
@@ -319,6 +340,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||
return err
|
||||
}
|
||||
s.recordPeerSyncEntry(peerKey.String(), netMap, metahash)
|
||||
|
||||
updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
@@ -340,7 +362,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, metahash, updates, srv, syncStart)
|
||||
}
|
||||
|
||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||
@@ -410,8 +432,9 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
// It implements a backpressure mechanism that sends the first update immediately,
|
||||
// then debounces subsequent rapid updates, ensuring only the latest update is sent
|
||||
// after a quiet period.
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||
// after a quiet period. peerMetaHash is forwarded to sendUpdate so the peer-sync
|
||||
// cache can record the serial this peer just received.
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerMetaHash uint64, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
|
||||
// Create a debouncer for this peer connection
|
||||
@@ -436,7 +459,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
// Send immediately (first update or after quiet period)
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerMetaHash, update, srv, streamStartTime); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
@@ -450,7 +473,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
|
||||
for _, pendingUpdate := range pendingUpdates {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerMetaHash, pendingUpdate, srv, streamStartTime); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
@@ -468,7 +491,9 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||
// For MessageTypeNetworkMap updates it records the delivered serial in the
|
||||
// peer-sync cache so a subsequent Sync with the same serial can take the fast path.
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerMetaHash uint64, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
@@ -488,6 +513,9 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
if update.MessageType == network_map.MessageTypeNetworkMap {
|
||||
s.recordPeerSyncEntryFromUpdate(peerKey.String(), update, peerMetaHash)
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
@@ -772,6 +800,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
s.invalidatePeerSyncEntry(peerKey.String())
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
if err != nil {
|
||||
|
||||
359
management/internals/shared/grpc/sync_fast_path.go
Normal file
359
management/internals/shared/grpc/sync_fast_path.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// peerGroupFetcher returns the group IDs a peer belongs to. It is a dependency
// of buildFastPathResponse so tests can inject a stub without a real store.
// The production implementation is (*Server).fetchPeerGroups.
type peerGroupFetcher func(ctx context.Context, accountID, peerID string) ([]string, error)
|
||||
|
||||
// peerSyncEntry records what the server last delivered to a peer on Sync so we
// can decide whether the next Sync can skip the full network map computation.
type peerSyncEntry struct {
	// Serial is the NetworkMap.Serial the server last included in a full map
	// delivered to this peer.
	Serial uint64
	// MetaHash is the metaHash() value of the peer metadata at the time of that
	// delivery, used to detect a meta change on reconnect.
	MetaHash uint64
}

// shouldSkipNetworkMap reports whether a Sync request from this peer can be
// answered with a lightweight NetbirdConfig-only response instead of a full
// map computation. All conditions must hold:
//   - the peer is not Android (Android's GrpcClient.GetNetworkMap errors on nil map)
//   - the cache holds an entry for this peer
//   - the cached serial is non-zero (guard against uninitialised entries)
//   - the cached serial matches the current account serial
//   - the cached meta hash matches the incoming meta hash
func shouldSkipNetworkMap(goOS string, hit bool, cached peerSyncEntry, currentSerial, incomingMetaHash uint64) bool {
	// Android clients cannot handle a response that omits the network map.
	if strings.EqualFold(goOS, "android") {
		return false
	}
	// Remaining guards collapse into one conjunction: cache hit, initialised
	// (non-zero) serial, up-to-date serial, and unchanged peer metadata.
	return hit &&
		cached.Serial != 0 &&
		cached.Serial == currentSerial &&
		cached.MetaHash == incomingMetaHash
}
|
||||
|
||||
// buildFastPathResponse constructs a SyncResponse containing only NetbirdConfig
|
||||
// with fresh TURN/Relay tokens, mirroring the shape used by
|
||||
// TimeBasedAuthSecretsManager when pushing token refreshes. The response omits
|
||||
// NetworkMap, PeerConfig, Checks and RemotePeers; the client keeps its existing
|
||||
// state and only refreshes its control-plane credentials.
|
||||
func buildFastPathResponse(
|
||||
ctx context.Context,
|
||||
cfg *nbconfig.Config,
|
||||
secrets SecretsManager,
|
||||
settingsMgr settings.Manager,
|
||||
fetchGroups peerGroupFetcher,
|
||||
peer *nbpeer.Peer,
|
||||
) *proto.SyncResponse {
|
||||
var turnToken *Token
|
||||
if cfg != nil && cfg.TURNConfig != nil && cfg.TURNConfig.TimeBasedCredentials {
|
||||
if t, err := secrets.GenerateTurnToken(); err == nil {
|
||||
turnToken = t
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("fast path: generate TURN token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var relayToken *Token
|
||||
if cfg != nil && cfg.Relay != nil && len(cfg.Relay.Addresses) > 0 {
|
||||
if t, err := secrets.GenerateRelayToken(); err == nil {
|
||||
relayToken = t
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("fast path: generate relay token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var extraSettings *nbtypes.ExtraSettings
|
||||
extraSettingsStart := time.Now()
|
||||
if es, err := settingsMgr.GetExtraSettings(ctx, peer.AccountID); err != nil {
|
||||
log.WithContext(ctx).Debugf("fast path: get extra settings: %v", err)
|
||||
} else {
|
||||
extraSettings = es
|
||||
}
|
||||
log.WithContext(ctx).Debugf("fast path: GetExtraSettings took %s", time.Since(extraSettingsStart))
|
||||
|
||||
nbConfig := toNetbirdConfig(cfg, turnToken, relayToken, extraSettings)
|
||||
|
||||
var peerGroups []string
|
||||
if fetchGroups != nil {
|
||||
start := time.Now()
|
||||
if ids, err := fetchGroups(ctx, peer.AccountID, peer.ID); err != nil {
|
||||
log.WithContext(ctx).Debugf("fast path: get peer group ids: %v", err)
|
||||
} else {
|
||||
peerGroups = ids
|
||||
}
|
||||
log.WithContext(ctx).Debugf("fast path: get peer groups took %s", time.Since(start))
|
||||
}
|
||||
|
||||
extendStart := time.Now()
|
||||
nbConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
log.WithContext(ctx).Debugf("fast path: ExtendNetBirdConfig took %s", time.Since(extendStart))
|
||||
|
||||
return &proto.SyncResponse{NetbirdConfig: nbConfig}
|
||||
}
|
||||
|
||||
// tryFastPathSync decides whether the current Sync can be answered with a
// lightweight NetbirdConfig-only response. When the fast path runs, it takes
// over the whole Sync handler (MarkPeerConnected, send, OnPeerConnected,
// SetupRefresh, handleUpdates) and the returned took value is true.
//
// When took is true the caller must return the accompanying err. When took is
// false the caller falls through to the existing slow path.
func (s *Server) tryFastPathSync(
	ctx context.Context,
	reqStart, syncStart time.Time,
	accountID string,
	peerKey wgtypes.Key,
	peerMeta nbpeer.PeerSystemMeta,
	realIP net.IP,
	peerMetaHash uint64,
	srv proto.ManagementService_SyncServer,
	unlock *func(),
) (took bool, err error) {
	// A nil cache or a disabled feature flag turns the fast path off entirely.
	if s.peerSerialCache == nil {
		return false, nil
	}
	if !s.fastPathFlag.Enabled() {
		return false, nil
	}
	// Android clients error on a response without a network map; always take
	// the slow path for them (also re-checked inside shouldSkipNetworkMap).
	if strings.EqualFold(peerMeta.GoOS, "android") {
		return false, nil
	}

	// Any store failure here simply falls back to the slow path; the fast path
	// never surfaces its own lookup errors to the caller.
	networkStart := time.Now()
	network, err := s.accountManager.GetStore().GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		log.WithContext(ctx).Debugf("fast path: lookup account network: %v", err)
		return false, nil
	}
	log.WithContext(ctx).Debugf("fast path: initial GetAccountNetwork took %s", time.Since(networkStart))

	// Eligibility: the cached (serial, meta-hash) for this peer must match the
	// current account serial and the incoming meta hash.
	eligibilityStart := time.Now()
	cached, hit := s.peerSerialCache.Get(peerKey.String())
	if !shouldSkipNetworkMap(peerMeta.GoOS, hit, cached, network.CurrentSerial(), peerMetaHash) {
		log.WithContext(ctx).Debugf("fast path: eligibility check (miss) took %s", time.Since(eligibilityStart))
		return false, nil
	}
	log.WithContext(ctx).Debugf("fast path: eligibility check (hit) took %s", time.Since(eligibilityStart))

	// commitFastPath re-checks the serial after subscribing to close the race
	// window between the eligibility check above and the subscription.
	peer, updates, committed := s.commitFastPath(ctx, accountID, peerKey, realIP, syncStart, network.CurrentSerial())
	if !committed {
		return false, nil
	}

	return true, s.runFastPathSync(ctx, reqStart, syncStart, accountID, peerKey, peer, updates, peerMetaHash, srv, unlock)
}
|
||||
|
||||
// commitFastPath fetches the peer, subscribes it to network-map updates,
// re-checks the account serial to close the race between the eligibility
// check and the subscription, and only then commits MarkPeerConnected. If
// the serial advanced in the race window the update-channel subscription is
// torn down (no MarkPeerConnected is written, so the slow path is free to
// run its own SyncAndMarkPeer cleanly) and the caller falls back to the
// slow path. Returns committed=false on any failure that should not block
// the slow path from running.
func (s *Server) commitFastPath(
	ctx context.Context,
	accountID string,
	peerKey wgtypes.Key,
	realIP net.IP,
	syncStart time.Time,
	expectedSerial uint64,
) (*nbpeer.Peer, chan *network_map.UpdateMessage, bool) {
	commitStart := time.Now()
	defer func() {
		log.WithContext(ctx).Debugf("fast path: commitFastPath took %s", time.Since(commitStart))
	}()

	getPeerStart := time.Now()
	peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
	if err != nil {
		log.WithContext(ctx).Debugf("fast path: lookup peer %s: %v", peerKey.String(), err)
		return nil, nil, false
	}
	log.WithContext(ctx).Debugf("fast path: GetPeerByPeerPubKey took %s", time.Since(getPeerStart))

	// Subscribe before the serial re-check so that no update published after
	// this point can be missed by the peer.
	onConnectedStart := time.Now()
	updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
	if err != nil {
		log.WithContext(ctx).Debugf("fast path: notify peer connected for %s: %v", peerKey.String(), err)
		return nil, nil, false
	}
	log.WithContext(ctx).Debugf("fast path: OnPeerConnected took %s", time.Since(onConnectedStart))

	// Re-read the serial after subscribing; on failure the subscription must
	// be torn down before falling back, otherwise the slow path would double-subscribe.
	recheckStart := time.Now()
	network, err := s.accountManager.GetStore().GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		log.WithContext(ctx).Debugf("fast path: re-check account network: %v", err)
		s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
		return nil, nil, false
	}
	log.WithContext(ctx).Debugf("fast path: re-check GetAccountNetwork took %s", time.Since(recheckStart))

	// Serial moved in the race window: the cached map is stale, so let the
	// slow path deliver a full map instead.
	if network.CurrentSerial() != expectedSerial {
		log.WithContext(ctx).Debugf("fast path: serial advanced from %d to %d after subscribe, falling back to slow path for peer %s", expectedSerial, network.CurrentSerial(), peerKey.String())
		s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
		return nil, nil, false
	}

	// MarkPeerConnected failure is logged but deliberately does not abort the
	// fast path: the response can still be served.
	markStart := time.Now()
	if err := s.accountManager.MarkPeerConnected(ctx, peerKey.String(), true, realIP, accountID, syncStart); err != nil {
		log.WithContext(ctx).Warnf("fast path: mark connected for peer %s: %v", peerKey.String(), err)
	}
	log.WithContext(ctx).Debugf("fast path: MarkPeerConnected took %s", time.Since(markStart))

	return peer, updates, true
}
|
||||
|
||||
// runFastPathSync executes the fast path: send the lean response, kick off
// token refresh, release the per-peer lock, then block on handleUpdates until
// the stream is closed. Peer lookup and subscription have already been
// performed by commitFastPath so the race between eligibility check and
// subscription is already closed.
func (s *Server) runFastPathSync(
	ctx context.Context,
	reqStart, syncStart time.Time,
	accountID string,
	peerKey wgtypes.Key,
	peer *nbpeer.Peer,
	updates chan *network_map.UpdateMessage,
	peerMetaHash uint64,
	srv proto.ManagementService_SyncServer,
	unlock *func(),
) error {
	sendStart := time.Now()
	if err := s.sendFastPathResponse(ctx, peerKey, peer, srv); err != nil {
		log.WithContext(ctx).Debugf("fast path: send response for peer %s: %v", peerKey.String(), err)
		// Release the sync semaphore and undo connected-state bookkeeping
		// before surfacing the error, mirroring the slow path's failure handling.
		s.syncSem.Add(-1)
		s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
		return err
	}
	log.WithContext(ctx).Debugf("fast path: sendFastPathResponse took %s", time.Since(sendStart))

	s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)

	// Release the caller's per-peer lock now that the initial response is out;
	// nil it so the caller's deferred unlock does not fire twice.
	if unlock != nil && *unlock != nil {
		(*unlock)()
		*unlock = nil
	}

	if s.appMetrics != nil {
		s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
	}
	log.WithContext(ctx).Debugf("Sync (fast path) took %s", time.Since(reqStart))

	// Drop the sync semaphore before blocking for the lifetime of the stream.
	s.syncSem.Add(-1)

	return s.handleUpdates(ctx, accountID, peerKey, peer, peerMetaHash, updates, srv, syncStart)
}
|
||||
|
||||
// sendFastPathResponse builds a NetbirdConfig-only SyncResponse, encrypts it
|
||||
// with the peer's WireGuard key and pushes it over the stream.
|
||||
func (s *Server) sendFastPathResponse(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, srv proto.ManagementService_SyncServer) error {
|
||||
resp := buildFastPathResponse(ctx, s.config, s.secretsManager, s.settingsManager, s.fetchPeerGroups, peer)
|
||||
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed getting server key")
|
||||
}
|
||||
|
||||
body, err := encryption.EncryptMessage(peerKey, key, resp)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error encrypting fast-path sync response")
|
||||
}
|
||||
|
||||
if err := srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: body,
|
||||
}); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed sending fast-path sync response: %v", err)
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchPeerGroups is the peerGroupFetcher dependency injected into
// buildFastPathResponse in production. It reads group membership straight
// from the account store with no locking.
//
// NOTE(review): an earlier comment claimed a nil accountManager store is
// treated as "no groups", but there is no nil guard here — GetStore() is
// assumed non-nil in production. Confirm before relying on that claim.
func (s *Server) fetchPeerGroups(ctx context.Context, accountID, peerID string) ([]string, error) {
	return s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
}
|
||||
|
||||
// recordPeerSyncEntry writes the serial just delivered to this peer so a
|
||||
// subsequent reconnect can take the fast path. Called after the slow path's
|
||||
// sendInitialSync has pushed a full map. A nil cache disables the fast path.
|
||||
func (s *Server) recordPeerSyncEntry(peerKey string, netMap *nbtypes.NetworkMap, peerMetaHash uint64) {
|
||||
if s.peerSerialCache == nil {
|
||||
return
|
||||
}
|
||||
if !s.fastPathFlag.Enabled() {
|
||||
return
|
||||
}
|
||||
if netMap == nil || netMap.Network == nil {
|
||||
return
|
||||
}
|
||||
serial := netMap.Network.CurrentSerial()
|
||||
if serial == 0 {
|
||||
return
|
||||
}
|
||||
s.peerSerialCache.Set(peerKey, peerSyncEntry{Serial: serial, MetaHash: peerMetaHash})
|
||||
}
|
||||
|
||||
// recordPeerSyncEntryFromUpdate is the sendUpdate equivalent of
|
||||
// recordPeerSyncEntry: it extracts the serial from a streamed NetworkMap update
|
||||
// so the cache stays in sync with what the peer most recently received.
|
||||
func (s *Server) recordPeerSyncEntryFromUpdate(peerKey string, update *network_map.UpdateMessage, peerMetaHash uint64) {
|
||||
if s.peerSerialCache == nil || update == nil || update.Update == nil || update.Update.NetworkMap == nil {
|
||||
return
|
||||
}
|
||||
if !s.fastPathFlag.Enabled() {
|
||||
return
|
||||
}
|
||||
serial := update.Update.NetworkMap.GetSerial()
|
||||
if serial == 0 {
|
||||
return
|
||||
}
|
||||
s.peerSerialCache.Set(peerKey, peerSyncEntry{Serial: serial, MetaHash: peerMetaHash})
|
||||
}
|
||||
|
||||
// invalidatePeerSyncEntry is called after a successful Login so the next Sync
|
||||
// is guaranteed to deliver a full map, picking up whatever changed in the
|
||||
// login (SSH key rotation, approval state, user binding, etc.).
|
||||
func (s *Server) invalidatePeerSyncEntry(peerKey string) {
|
||||
if s.peerSerialCache == nil {
|
||||
return
|
||||
}
|
||||
s.peerSerialCache.Delete(peerKey)
|
||||
}
|
||||
163
management/internals/shared/grpc/sync_fast_path_response_test.go
Normal file
163
management/internals/shared/grpc/sync_fast_path_response_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func fastPathTestPeer() *nbpeer.Peer {
|
||||
return &nbpeer.Peer{
|
||||
ID: "peer-id",
|
||||
AccountID: "account-id",
|
||||
Key: "pubkey",
|
||||
}
|
||||
}
|
||||
|
||||
// fastPathTestSecrets builds a real TimeBasedAuthSecretsManager over the given
// TURN/relay configs, backed by a mock settings manager and a mock groups
// manager. The gomock controller is finished via t.Cleanup.
func fastPathTestSecrets(t *testing.T, turn *config.TURNConfig, relay *config.Relay) *TimeBasedAuthSecretsManager {
	t.Helper()
	peersManager := update_channel.NewPeersUpdateManager(nil)
	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)
	settingsMock := settings.NewMockManager(ctrl)
	secrets, err := NewTimeBasedAuthSecretsManager(peersManager, turn, relay, settingsMock, groups.NewManagerMock())
	require.NoError(t, err, "secrets manager initialisation must succeed")
	return secrets
}
|
||||
|
||||
func noGroupsFetcher(context.Context, string, string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestBuildFastPathResponse_TimeBasedTURNAndRelay_FreshTokens verifies the
// happy path: with time-based TURN credentials and a relay configured, the
// fast-path response carries freshly minted TURN and relay tokens, passes
// signal/STUN config through, and omits every network-map-related field.
func TestBuildFastPathResponse_TimeBasedTURNAndRelay_FreshTokens(t *testing.T) {
	ttl := util.Duration{Duration: time.Hour}
	turnCfg := &config.TURNConfig{
		CredentialsTTL:       ttl,
		Secret:               "turn-secret",
		Turns:                []*config.Host{TurnTestHost},
		TimeBasedCredentials: true,
	}
	relayCfg := &config.Relay{
		Addresses:      []string{"rel.example:443"},
		CredentialsTTL: ttl,
		Secret:         "relay-secret",
	}
	cfg := &config.Config{
		TURNConfig: turnCfg,
		Relay:      relayCfg,
		Signal:     &config.Host{URI: "signal.example:443", Proto: config.HTTPS},
		Stuns:      []*config.Host{{URI: "stun.example:3478", Proto: config.UDP}},
	}

	secrets := fastPathTestSecrets(t, turnCfg, relayCfg)

	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)
	settingsMock := settings.NewMockManager(ctrl)
	settingsMock.EXPECT().GetExtraSettings(gomock.Any(), "account-id").Return(&types.ExtraSettings{}, nil).AnyTimes()

	resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer())

	// Fast path must be config-only: no map, peer config, checks or peers.
	require.NotNil(t, resp, "response must not be nil")
	assert.Nil(t, resp.NetworkMap, "fast path must omit NetworkMap")
	assert.Nil(t, resp.PeerConfig, "fast path must omit PeerConfig")
	assert.Empty(t, resp.Checks, "fast path must omit posture checks")
	assert.Empty(t, resp.RemotePeers, "fast path must omit remote peers")

	// Time-based TURN credentials are generated, not copied from config.
	require.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must be present on fast path")
	require.Len(t, resp.NetbirdConfig.Turns, 1, "time-based TURN credentials must be present")
	assert.NotEmpty(t, resp.NetbirdConfig.Turns[0].User, "TURN user must be populated")
	assert.NotEmpty(t, resp.NetbirdConfig.Turns[0].Password, "TURN password must be populated")

	require.NotNil(t, resp.NetbirdConfig.Relay, "Relay config must be present when configured")
	assert.NotEmpty(t, resp.NetbirdConfig.Relay.TokenPayload, "relay token payload must be populated")
	assert.NotEmpty(t, resp.NetbirdConfig.Relay.TokenSignature, "relay token signature must be populated")
	assert.Equal(t, []string{"rel.example:443"}, resp.NetbirdConfig.Relay.Urls, "relay URLs passthrough")

	require.NotNil(t, resp.NetbirdConfig.Signal, "Signal config must be present when configured")
	assert.Equal(t, "signal.example:443", resp.NetbirdConfig.Signal.Uri, "signal URI passthrough")
	require.Len(t, resp.NetbirdConfig.Stuns, 1, "STUNs must be passed through")
	assert.Equal(t, "stun.example:3478", resp.NetbirdConfig.Stuns[0].Uri, "STUN URI passthrough")
}
|
||||
|
||||
// TestBuildFastPathResponse_StaticTURNCredentials verifies that with
// TimeBasedCredentials disabled the configured static TURN username/password
// are passed through untouched and no relay section is emitted.
func TestBuildFastPathResponse_StaticTURNCredentials(t *testing.T) {
	ttl := util.Duration{Duration: time.Hour}
	staticHost := &config.Host{
		URI:      "turn:static.example:3478",
		Proto:    config.UDP,
		Username: "preset-user",
		Password: "preset-pass",
	}
	turnCfg := &config.TURNConfig{
		CredentialsTTL:       ttl,
		Secret:               "turn-secret",
		Turns:                []*config.Host{staticHost},
		TimeBasedCredentials: false,
	}
	cfg := &config.Config{TURNConfig: turnCfg}

	// Use a relay-free secrets manager; static TURN path does not consult it.
	secrets := fastPathTestSecrets(t, turnCfg, nil)

	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)
	settingsMock := settings.NewMockManager(ctrl)
	settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()

	resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer())

	require.NotNil(t, resp.NetbirdConfig)
	require.Len(t, resp.NetbirdConfig.Turns, 1, "static TURN must appear in response")
	assert.Equal(t, "preset-user", resp.NetbirdConfig.Turns[0].User, "static user passthrough")
	assert.Equal(t, "preset-pass", resp.NetbirdConfig.Turns[0].Password, "static password passthrough")
	assert.Nil(t, resp.NetbirdConfig.Relay, "no Relay when Relay config is nil")
}
|
||||
|
||||
// TestBuildFastPathResponse_NoRelayConfigured_NoRelaySection verifies that an
// empty server config still yields a non-nil NetbirdConfig with no relay or
// TURN entries.
func TestBuildFastPathResponse_NoRelayConfigured_NoRelaySection(t *testing.T) {
	cfg := &config.Config{}
	secrets := fastPathTestSecrets(t, nil, nil)

	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)
	settingsMock := settings.NewMockManager(ctrl)
	settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()

	resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer())
	require.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must be non-nil even without relay/turn")
	assert.Nil(t, resp.NetbirdConfig.Relay, "Relay must be absent when not configured")
	assert.Empty(t, resp.NetbirdConfig.Turns, "Turns must be empty when not configured")
}
|
||||
|
||||
// TestBuildFastPathResponse_ExtraSettingsErrorStillReturnsResponse verifies
// the best-effort contract: a failing GetExtraSettings must not prevent a
// (degraded) fast-path response from being returned.
func TestBuildFastPathResponse_ExtraSettingsErrorStillReturnsResponse(t *testing.T) {
	cfg := &config.Config{}
	secrets := fastPathTestSecrets(t, nil, nil)

	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)
	settingsMock := settings.NewMockManager(ctrl)
	settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(nil, assertAnError).AnyTimes()

	resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer())
	require.NotNil(t, resp, "extra settings failure must degrade gracefully into an empty fast-path response")
	assert.Nil(t, resp.NetworkMap, "NetworkMap still omitted on degraded path")
}
|
||||
|
||||
// errForTests is a trivial string-backed error type for simulating dependency
// failures in fast-path tests.
type errForTests string

// Error implements the error interface by returning the underlying string.
func (e errForTests) Error() string { return string(e) }

// assertAnError is a sentinel used by fast-path tests that need to simulate a
// dependency failure without caring about the error value.
var assertAnError = errForTests("simulated")
|
||||
@@ -391,7 +391,9 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
|
||||
peerSerialCache := nbgrpc.NewPeerSerialCache(ctx, cacheStore, time.Minute)
|
||||
fastPathFlag := nbgrpc.NewFastPathFlag(true)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, peerSerialCache, fastPathFlag)
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
@@ -256,6 +256,8 @@ func startServer(
|
||||
server.MockIntegratedValidator{},
|
||||
networkMapController,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating management server: %v", err)
|
||||
|
||||
297
management/server/sync_fast_path_test.go
Normal file
297
management/server/sync_fast_path_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// skipOnWindows skips the calling test on Windows. The in-process gRPC
|
||||
// harness uses Unix socket / path conventions that do not cleanly map to
|
||||
// Windows.
|
||||
func skipOnWindows(t *testing.T) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping on windows; harness uses unix path conventions")
|
||||
}
|
||||
}
|
||||
|
||||
// fastPathTestConfig returns a management config with time-based TURN
// credentials, a relay and a signal endpoint — everything the fast path needs
// to mint fresh tokens. Datadir is a per-test temp directory.
func fastPathTestConfig(t *testing.T) *config.Config {
	t.Helper()
	return &config.Config{
		Datadir: t.TempDir(),
		Stuns: []*config.Host{{
			Proto: "udp",
			URI:   "stun:stun.example:3478",
		}},
		TURNConfig: &config.TURNConfig{
			TimeBasedCredentials: true,
			CredentialsTTL:       util.Duration{Duration: time.Hour},
			Secret:               "turn-secret",
			Turns: []*config.Host{{
				Proto: "udp",
				URI:   "turn:turn.example:3478",
			}},
		},
		Relay: &config.Relay{
			Addresses:      []string{"rel.example:443"},
			CredentialsTTL: util.Duration{Duration: time.Hour},
			Secret:         "relay-secret",
		},
		Signal: &config.Host{
			Proto: "http",
			URI:   "signal.example:10000",
		},
		HttpConfig: nil,
	}
}
|
||||
|
||||
// openSync opens a Sync stream with the given meta and returns the decoded first
// SyncResponse plus a cancel function. The caller must call cancel() to release
// server-side routines before opening a new stream for the same peer.
func openSync(t *testing.T, client mgmtProto.ManagementServiceClient, serverKey, peerKey wgtypes.Key, meta *mgmtProto.PeerSystemMeta) (*mgmtProto.SyncResponse, context.CancelFunc) {
	t.Helper()

	// The request body is encrypted with the server's public key, as a real
	// client would do.
	req := &mgmtProto.SyncRequest{Meta: meta}
	body, err := encryption.EncryptMessage(serverKey, peerKey, req)
	require.NoError(t, err, "encrypt sync request")

	ctx, cancel := context.WithCancel(context.Background())
	stream, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
		WgPubKey: peerKey.PublicKey().String(),
		Body:     body,
	})
	require.NoError(t, err, "open sync stream")

	// Only the first message is consumed; any follow-up updates stay queued
	// on the stream until cancel() tears it down.
	enc := &mgmtProto.EncryptedMessage{}
	require.NoError(t, stream.RecvMsg(enc), "receive first sync response")

	resp := &mgmtProto.SyncResponse{}
	require.NoError(t, encryption.DecryptMessage(serverKey, peerKey, enc.Body, resp), "decrypt sync response")

	return resp, cancel
}
|
||||
|
||||
// waitForPeerDisconnect polls until the account manager reports the peer as
|
||||
// disconnected (Status.Connected == false), which happens once the server's
|
||||
// handleUpdates goroutine has run cancelPeerRoutines for the cancelled
|
||||
// stream. The deadline is bounded so a stuck server fails the test rather
|
||||
// than hanging. Replaces the former fixed 50ms sleep which was CI-flaky
|
||||
// under load or with the race detector on.
|
||||
func waitForPeerDisconnect(t *testing.T, am *DefaultAccountManager, peerPubKey string) {
|
||||
t.Helper()
|
||||
require.Eventually(t, func() bool {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return !peer.Status.Connected
|
||||
}, 2*time.Second, 10*time.Millisecond, "peer %s should be marked disconnected after stream cancel", peerPubKey)
|
||||
}
|
||||
|
||||
func baseLinuxMeta() *mgmtProto.PeerSystemMeta {
|
||||
return &mgmtProto.PeerSystemMeta{
|
||||
Hostname: "linux-host",
|
||||
GoOS: "linux",
|
||||
OS: "linux",
|
||||
Platform: "x86_64",
|
||||
Kernel: "5.15.0",
|
||||
NetbirdVersion: "0.70.0",
|
||||
}
|
||||
}
|
||||
|
||||
func androidMeta() *mgmtProto.PeerSystemMeta {
|
||||
return &mgmtProto.PeerSystemMeta{
|
||||
Hostname: "android-host",
|
||||
GoOS: "android",
|
||||
OS: "android",
|
||||
Platform: "arm64",
|
||||
Kernel: "4.19",
|
||||
NetbirdVersion: "0.70.0",
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncFastPath_FirstSync_SendsFullMap(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, cancel := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta())
|
||||
defer cancel()
|
||||
|
||||
require.NotNil(t, resp.NetworkMap, "first sync for a fresh peer must deliver a full NetworkMap")
|
||||
assert.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must accompany the full map")
|
||||
}
|
||||
|
||||
func TestSyncFastPath_SecondSync_MatchingSerial_SkipsMap(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta())
|
||||
require.NotNil(t, first.NetworkMap, "first sync primes cache with a full map")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, keys[0].PublicKey().String())
|
||||
|
||||
second, cancel2 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta())
|
||||
defer cancel2()
|
||||
|
||||
assert.Nil(t, second.NetworkMap, "second sync with unchanged state must omit NetworkMap")
|
||||
require.NotNil(t, second.NetbirdConfig, "fast path must still deliver NetbirdConfig")
|
||||
assert.NotEmpty(t, second.NetbirdConfig.Turns, "time-based TURN credentials must be refreshed on fast path")
|
||||
require.NotNil(t, second.NetbirdConfig.Relay, "relay config must be present on fast path")
|
||||
assert.NotEmpty(t, second.NetbirdConfig.Relay.TokenPayload, "relay token must be refreshed on fast path")
|
||||
}
|
||||
|
||||
func TestSyncFastPath_AndroidNeverSkips(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := openSync(t, client, *serverKey, *keys[0], androidMeta())
|
||||
require.NotNil(t, first.NetworkMap, "android first sync must deliver a full map")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, keys[0].PublicKey().String())
|
||||
|
||||
second, cancel2 := openSync(t, client, *serverKey, *keys[0], androidMeta())
|
||||
defer cancel2()
|
||||
|
||||
require.NotNil(t, second.NetworkMap, "android reconnects must never take the fast path")
|
||||
}
|
||||
|
||||
func TestSyncFastPath_MetaChanged_SendsFullMap(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta())
|
||||
require.NotNil(t, first.NetworkMap, "first sync primes cache")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, keys[0].PublicKey().String())
|
||||
|
||||
changed := baseLinuxMeta()
|
||||
changed.Hostname = "linux-host-renamed"
|
||||
second, cancel2 := openSync(t, client, *serverKey, *keys[0], changed)
|
||||
defer cancel2()
|
||||
|
||||
require.NotNil(t, second.NetworkMap, "meta change must force a full map even when serial matches")
|
||||
}
|
||||
|
||||
func TestSyncFastPath_LoginInvalidatesCache(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = loginPeerWithValidSetupKey(key, client)
|
||||
require.NoError(t, err, "initial login must succeed")
|
||||
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := openSync(t, client, *serverKey, key, baseLinuxMeta())
|
||||
require.NotNil(t, first.NetworkMap, "first sync primes cache")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, key.PublicKey().String())
|
||||
|
||||
// A subsequent login (e.g. SSH key rotation, re-auth) must clear the cache.
|
||||
_, err = loginPeerWithValidSetupKey(key, client)
|
||||
require.NoError(t, err, "second login must succeed")
|
||||
|
||||
second, cancel2 := openSync(t, client, *serverKey, key, baseLinuxMeta())
|
||||
defer cancel2()
|
||||
require.NotNil(t, second.NetworkMap, "Login must invalidate the cache so the next Sync delivers a full map")
|
||||
}
|
||||
|
||||
func TestSyncFastPath_OtherPeerRegistered_ForcesFullMap(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta())
|
||||
require.NotNil(t, first.NetworkMap, "first sync primes cache at serial N")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, keys[0].PublicKey().String())
|
||||
|
||||
// Registering another peer bumps the account serial via IncrementNetworkSerial.
|
||||
_, err = registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
second, cancel2 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta())
|
||||
defer cancel2()
|
||||
require.NotNil(t, second.NetworkMap, "serial advance must force a full map even if meta is unchanged")
|
||||
}
|
||||
181
management/server/sync_legacy_wire_test.go
Normal file
181
management/server/sync_legacy_wire_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto" //nolint:staticcheck // matches the generator
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// sendWireFixture replays a frozen SyncRequest wire fixture as `peerKey` and
|
||||
// returns the decoded first SyncResponse plus a cancel function. The caller
|
||||
// must invoke cancel() so the server releases per-peer routines.
|
||||
func sendWireFixture(t *testing.T, client mgmtProto.ManagementServiceClient, serverKey, peerKey wgtypes.Key, fixturePath string) (*mgmtProto.SyncResponse, context.CancelFunc) {
|
||||
t.Helper()
|
||||
|
||||
raw, err := os.ReadFile(fixturePath)
|
||||
require.NoError(t, err, "read fixture %s", fixturePath)
|
||||
|
||||
req := &mgmtProto.SyncRequest{}
|
||||
require.NoError(t, proto.Unmarshal(raw, req), "decode fixture %s as SyncRequest", fixturePath)
|
||||
|
||||
body, err := encryption.EncryptMessage(serverKey, peerKey, req)
|
||||
require.NoError(t, err, "encrypt sync request")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
stream, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerKey.PublicKey().String(),
|
||||
Body: body,
|
||||
})
|
||||
require.NoError(t, err, "open sync stream")
|
||||
|
||||
enc := &mgmtProto.EncryptedMessage{}
|
||||
require.NoError(t, stream.RecvMsg(enc), "receive first sync response")
|
||||
|
||||
resp := &mgmtProto.SyncResponse{}
|
||||
require.NoError(t, encryption.DecryptMessage(serverKey, peerKey, enc.Body, resp), "decrypt sync response")
|
||||
return resp, cancel
|
||||
}
|
||||
|
||||
func TestSync_WireFixture_LegacyClients_AlwaysReceiveFullMap(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
cases := []struct {
|
||||
name string
|
||||
fixture string
|
||||
}{
|
||||
{"v0.20.0 empty SyncRequest", "testdata/sync_request_wire/v0_20_0.bin"},
|
||||
{"v0.40.0 SyncRequest with Meta", "testdata/sync_request_wire/v0_40_0.bin"},
|
||||
{"v0.60.0 SyncRequest with Meta", "testdata/sync_request_wire/v0_60_0.bin"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
abs, err := filepath.Abs(tc.fixture)
|
||||
require.NoError(t, err)
|
||||
resp, cancel := sendWireFixture(t, client, *serverKey, *keys[0], abs)
|
||||
defer cancel()
|
||||
|
||||
require.NotNil(t, resp.NetworkMap, "legacy client first Sync must deliver a full NetworkMap")
|
||||
require.NotNil(t, resp.NetbirdConfig, "legacy client first Sync must include NetbirdConfig")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSync_WireFixture_LegacyClient_ReconnectStillGetsFullMap(t *testing.T) {
|
||||
// v0.40.x clients call GrpcClient.GetNetworkMap on every OS during
|
||||
// readInitialSettings — they error on nil NetworkMap. Without extra opt-in
|
||||
// signalling there is no way for the server to know this is a GetNetworkMap
|
||||
// call rather than a main Sync, so the server's fast path would break them
|
||||
// on reconnect. This test pins the currently accepted tradeoff: a legacy
|
||||
// v0.40 client gets a full map on the first Sync but a reconnect with an
|
||||
// unchanged metaHash hits the primed cache and goes through the fast path.
|
||||
// When a future proto opt-in lets the server distinguish these clients,
|
||||
// this assertion must be tightened to require.NotNil(second.NetworkMap).
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
abs, err := filepath.Abs("testdata/sync_request_wire/v0_40_0.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := sendWireFixture(t, client, *serverKey, *keys[0], abs)
|
||||
require.NotNil(t, first.NetworkMap, "first legacy sync receives full map and primes cache")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, keys[0].PublicKey().String())
|
||||
|
||||
second, cancel2 := sendWireFixture(t, client, *serverKey, *keys[0], abs)
|
||||
defer cancel2()
|
||||
require.Nil(t, second.NetworkMap, "documented legacy-reconnect tradeoff: warm cache entry causes fast path; update when proto opt-in lands")
|
||||
require.NotNil(t, second.NetbirdConfig, "fast path still delivers NetbirdConfig")
|
||||
}
|
||||
|
||||
func TestSync_WireFixture_AndroidReconnect_NeverSkips(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
abs, err := filepath.Abs("testdata/sync_request_wire/android_current.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := sendWireFixture(t, client, *serverKey, *keys[0], abs)
|
||||
require.NotNil(t, first.NetworkMap, "android first sync must deliver a full map")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, keys[0].PublicKey().String())
|
||||
|
||||
second, cancel2 := sendWireFixture(t, client, *serverKey, *keys[0], abs)
|
||||
defer cancel2()
|
||||
require.NotNil(t, second.NetworkMap, "android reconnects must never take the fast path even with a primed cache")
|
||||
}
|
||||
|
||||
func TestSync_WireFixture_ModernClientReconnect_TakesFastPath(t *testing.T) {
|
||||
skipOnWindows(t)
|
||||
mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t))
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, conn, err := createRawClient(addr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
keys, err := registerPeers(1, client)
|
||||
require.NoError(t, err)
|
||||
serverKey, err := getServerKey(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
abs, err := filepath.Abs("testdata/sync_request_wire/current.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
first, cancel1 := sendWireFixture(t, client, *serverKey, *keys[0], abs)
|
||||
require.NotNil(t, first.NetworkMap, "modern first sync primes cache")
|
||||
cancel1()
|
||||
waitForPeerDisconnect(t, am, keys[0].PublicKey().String())
|
||||
|
||||
second, cancel2 := sendWireFixture(t, client, *serverKey, *keys[0], abs)
|
||||
defer cancel2()
|
||||
require.Nil(t, second.NetworkMap, "modern reconnect with unchanged state must skip the NetworkMap")
|
||||
require.NotNil(t, second.NetbirdConfig, "fast path still delivers NetbirdConfig")
|
||||
}
|
||||
2
management/server/testdata/sync_request_wire/.gitignore
vendored
Normal file
2
management/server/testdata/sync_request_wire/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
current.bin
|
||||
android_current.bin
|
||||
31
management/server/testdata/sync_request_wire/README.md
vendored
Normal file
31
management/server/testdata/sync_request_wire/README.md
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
# SyncRequest wire-format fixtures
|
||||
|
||||
These files are the byte-for-byte contents of the `SyncRequest` proto a netbird
|
||||
client of each listed version would put on the wire. `sync_legacy_wire_test.go`
|
||||
decodes each file, wraps it in the current `EncryptedMessage` envelope and
|
||||
replays it through the in-process gRPC server to prove that the peer-sync fast
|
||||
path does not break historical clients.
|
||||
|
||||
File | Client era | Notes
|
||||
-----|-----------|------
|
||||
`v0_20_0.bin` | v0.20.x | `message SyncRequest {}` — no fields on the wire. Main Sync loop in v0.20 gracefully skips nil `NetworkMap`, so the fixture is expected to get a full map (empty Sync payload → cache miss → slow path). **Checked in — frozen snapshot.**
|
||||
`v0_40_0.bin` | v0.40.x | First release with `Meta` at tag 1. v0.40 calls `GrpcClient.GetNetworkMap` on every OS; fixture must continue to produce a full map. **Checked in — frozen snapshot.**
|
||||
`v0_60_0.bin` | v0.60.x | Same SyncRequest shape as v0.40 but tagged with a newer `NetbirdVersion`. **Checked in — frozen snapshot.**
|
||||
`current.bin` | latest | Fully-populated `PeerSystemMeta`. **Not checked in — regenerated at CI time by `generate.go`.**
|
||||
`android_current.bin` | latest, Android | Same shape as `current.bin` with `GoOS=android`; the server must never take the fast path even after the cache is primed. **Not checked in — regenerated at CI time by `generate.go`.**
|
||||
|
||||
## Regenerating
|
||||
|
||||
`generate.go` writes only `current.bin` and `android_current.bin`. CI invokes it
|
||||
before running the management test suite:
|
||||
|
||||
```sh
|
||||
go run ./management/server/testdata/sync_request_wire/generate.go
|
||||
```
|
||||
|
||||
Run the same command locally if you are running the wire tests by hand.
|
||||
|
||||
The three legacy fixtures are intentionally frozen. Do not regenerate them —
|
||||
their value is that they survive proto changes unchanged, so a future proto
|
||||
change that silently breaks the old wire format is caught by CI replaying the
|
||||
frozen bytes and failing to decode them.
|
||||
73
management/server/testdata/sync_request_wire/generate.go
vendored
Normal file
73
management/server/testdata/sync_request_wire/generate.go
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build ignore
|
||||
|
||||
// generate.go produces the SyncRequest wire-format fixtures that the current
|
||||
// netbird client (and the android variant) put on the wire. These two files
|
||||
// are regenerated at CI time — run with:
|
||||
//
|
||||
// go run ./management/server/testdata/sync_request_wire/generate.go
|
||||
//
|
||||
// The legacy fixtures (v0_20_0.bin, v0_40_0.bin, v0_60_0.bin) are frozen
|
||||
// snapshots of what older clients sent. They are checked in and intentionally
|
||||
// never regenerated here, so a future proto change that silently breaks the
|
||||
// old wire format is caught by CI replaying the frozen bytes.
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/golang/protobuf/proto" //nolint:staticcheck // wire-format stability
|
||||
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
outDir := filepath.Join("management", "server", "testdata", "sync_request_wire")
|
||||
if err := os.MkdirAll(outDir, 0o755); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "mkdir %s: %v\n", outDir, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fixtures := map[string]*mgmtProto.SyncRequest{
|
||||
// current: fully-populated meta a modern client would send.
|
||||
"current.bin": {
|
||||
Meta: &mgmtProto.PeerSystemMeta{
|
||||
Hostname: "modern-host",
|
||||
GoOS: "linux",
|
||||
OS: "linux",
|
||||
Platform: "x86_64",
|
||||
Kernel: "6.5.0",
|
||||
NetbirdVersion: "0.70.0",
|
||||
UiVersion: "0.70.0",
|
||||
KernelVersion: "6.5.0-rc1",
|
||||
},
|
||||
},
|
||||
|
||||
// android: exercises the never-skip branch regardless of cache state.
|
||||
"android_current.bin": {
|
||||
Meta: &mgmtProto.PeerSystemMeta{
|
||||
Hostname: "android-host",
|
||||
GoOS: "android",
|
||||
OS: "android",
|
||||
Platform: "arm64",
|
||||
Kernel: "4.19",
|
||||
NetbirdVersion: "0.70.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, msg := range fixtures {
|
||||
payload, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "marshal %s: %v\n", name, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
path := filepath.Join(outDir, name)
|
||||
if err := os.WriteFile(path, payload, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "write %s: %v\n", path, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("wrote %s (%d bytes)\n", path, len(payload))
|
||||
}
|
||||
}
|
||||
0
management/server/testdata/sync_request_wire/v0_20_0.bin
vendored
Normal file
0
management/server/testdata/sync_request_wire/v0_20_0.bin
vendored
Normal file
3
management/server/testdata/sync_request_wire/v0_40_0.bin
vendored
Normal file
3
management/server/testdata/sync_request_wire/v0_40_0.bin
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
|
||||
0
|
||||
v40-hostlinux4.15.0*x86_642linux:0.40.0
|
||||
3
management/server/testdata/sync_request_wire/v0_60_0.bin
vendored
Normal file
3
management/server/testdata/sync_request_wire/v0_60_0.bin
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
|
||||
0
|
||||
v60-hostlinux5.15.0*x86_642linux:0.60.0
|
||||
@@ -138,7 +138,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user