mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-19 23:29:56 +00:00
406 lines
12 KiB
Go
406 lines
12 KiB
Go
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/metadata"
|
|
|
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
|
"github.com/netbirdio/netbird/shared/management/proto"
|
|
)
|
|
|
|
// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records
|
|
// sent messages and returns pre-loaded ack responses from Recv.
|
|
type syncRecordingStream struct {
|
|
grpc.ServerStream
|
|
|
|
mu sync.Mutex
|
|
sent []*proto.SyncMappingsResponse
|
|
recvMsgs []*proto.SyncMappingsRequest
|
|
recvIdx int
|
|
}
|
|
|
|
func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.sent = append(s.sent, m)
|
|
return nil
|
|
}
|
|
|
|
// Recv returns the next scripted request and advances the cursor. Once every
// scripted message has been returned, it reports an error on each later call.
func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	next := s.recvIdx
	if next < len(s.recvMsgs) {
		s.recvIdx = next + 1
		return s.recvMsgs[next], nil
	}
	return nil, fmt.Errorf("no more recv messages")
}
|
|
|
|
// Context returns a background context. The remaining grpc.ServerStream
// methods below are no-ops that always succeed.
func (s *syncRecordingStream) Context() context.Context {
	return context.Background()
}

func (s *syncRecordingStream) SetHeader(_ metadata.MD) error {
	return nil
}

func (s *syncRecordingStream) SendHeader(_ metadata.MD) error {
	return nil
}

func (s *syncRecordingStream) SetTrailer(_ metadata.MD) {
	// Trailers are ignored.
}

func (s *syncRecordingStream) SendMsg(_ any) error {
	return nil
}

func (s *syncRecordingStream) RecvMsg(_ any) error {
	return nil
}
|
|
|
|
func ackMsg() *proto.SyncMappingsRequest {
|
|
return &proto.SyncMappingsRequest{
|
|
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
|
}
|
|
}
|
|
|
|
// TestSendSnapshotSync_BatchesWithAcks checks that a 7-service snapshot is
// split into batches of 3 (3+3+1), that only the final batch carries
// InitialSyncComplete, and that one ack is read after each non-final batch.
func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) {
	const (
		cluster       = "cluster.example.com"
		batchSize     = 3
		totalServices = 7 // 3 + 3 + 1 → 3 batches, 2 acks needed (not after last)
	)

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	srv := newSnapshotTestServer(t, batchSize)
	srv.serviceManager = mgr

	// One scripted ack for each non-final batch.
	stream := &syncRecordingStream{recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg()}}
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, syncStream: stream}

	require.NoError(t, srv.sendSnapshotSync(context.Background(), conn, stream))

	require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches")

	wantSizes := []int{3, 3, 1}
	for i, batch := range stream.sent {
		assert.Len(t, batch.Mapping, wantSizes[i])
		if i == len(wantSizes)-1 {
			assert.True(t, batch.InitialSyncComplete)
		} else {
			assert.False(t, batch.InitialSyncComplete)
		}
	}

	// Both scripted acks must have been read.
	assert.Equal(t, 2, stream.recvIdx)
}
|
|
|
|
// TestSendSnapshotSync_SingleBatchNoAckNeeded covers a snapshot that fits in a
// single batch. That batch is also the last one, so it must carry
// InitialSyncComplete and no ack may be read.
func TestSendSnapshotSync_SingleBatchNoAckNeeded(t *testing.T) {
	const (
		cluster       = "cluster.example.com"
		batchSize     = 100
		totalServices = 5
	)

	mgr := rpservice.NewMockManager(gomock.NewController(t))
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	srv := newSnapshotTestServer(t, batchSize)
	srv.serviceManager = mgr

	// Nothing is scripted, so any Recv call would return an error.
	stream := new(syncRecordingStream)
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, syncStream: stream}

	require.NoError(t, srv.sendSnapshotSync(context.Background(), conn, stream))

	require.Len(t, stream.sent, 1)
	only := stream.sent[0]
	assert.Len(t, only.Mapping, totalServices)
	assert.True(t, only.InitialSyncComplete)
	assert.Equal(t, 0, stream.recvIdx, "no acks should be consumed for a single batch")
}
|
|
|
|
// TestSendSnapshotSync_EmptySnapshot checks that when the manager returns no
// services, the server still sends exactly one (empty) response flagged
// InitialSyncComplete.
func TestSendSnapshotSync_EmptySnapshot(t *testing.T) {
	const cluster = "cluster.example.com"

	mgr := rpservice.NewMockManager(gomock.NewController(t))
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)

	srv := newSnapshotTestServer(t, 500)
	srv.serviceManager = mgr

	stream := new(syncRecordingStream)
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, syncStream: stream}

	require.NoError(t, srv.sendSnapshotSync(context.Background(), conn, stream))

	require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete")
	only := stream.sent[0]
	assert.Empty(t, only.Mapping)
	assert.True(t, only.InitialSyncComplete)
}
|
|
|
|
// TestSendSnapshotSync_MissingAckReturnsError checks that when Recv fails
// while waiting for an ack, sendSnapshotSync returns an error mentioning
// "receive ack" after having sent only the first batch.
func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) {
	const (
		cluster       = "cluster.example.com"
		batchSize     = 2
		totalServices = 4 // 2 batches → 1 ack needed, but we provide none
	)

	mgr := rpservice.NewMockManager(gomock.NewController(t))
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	srv := newSnapshotTestServer(t, batchSize)
	srv.serviceManager = mgr

	// Nothing is scripted, so the first Recv returns an error.
	stream := new(syncRecordingStream)
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, syncStream: stream}

	err := srv.sendSnapshotSync(context.Background(), conn, stream)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "receive ack")

	// The first batch goes out before the missing ack aborts the sync.
	require.Len(t, stream.sent, 1)
}
|
|
|
|
// TestSendSnapshotSync_WrongMessageInsteadOfAck checks that when the proxy
// replies to the first batch with an Init message instead of an ack,
// sendSnapshotSync fails with an "expected ack" error and sends no further
// batches.
func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) {
	const cluster = "cluster.example.com"
	const batchSize = 2
	const totalServices = 4 // 2 batches → 1 ack expected after the first

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	s := newSnapshotTestServer(t, batchSize)
	s.serviceManager = mgr

	// Send an init message instead of an ack.
	stream := &syncRecordingStream{
		recvMsgs: []*proto.SyncMappingsRequest{
			{Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}},
		},
	}
	conn := &proxyConnection{
		proxyID:    "proxy-a",
		address:    cluster,
		syncStream: stream,
	}

	err := s.sendSnapshotSync(context.Background(), conn, stream)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "expected ack")

	// Like the missing-ack case, only the first batch may have been sent. The
	// second batch must not go out once a non-ack reply has been seen.
	require.Len(t, stream.sent, 1, "no batch may be sent after a non-ack reply")
}
|
|
|
|
// TestSendSnapshotSync_BackPressureOrdering verifies that batches are sent
// strictly one at a time: batch N+1 goes out only after the ack for batch N
// has been received, and the final batch is not followed by an ack read.
func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) {
	const (
		cluster       = "cluster.example.com"
		batchSize     = 2
		totalServices = 6 // 3 batches, 2 acks
	)

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	srv := newSnapshotTestServer(t, batchSize)
	srv.serviceManager = mgr

	// The stream logs every Send/Recv into events so the order can be checked.
	var (
		mu     sync.Mutex
		events []string
	)
	ackCh := make(chan struct{}, 2)
	stream := &orderTrackingStream{mu: &mu, events: &events, ackCh: ackCh}
	conn := &proxyConnection{proxyID: "proxy-a", address: cluster, syncStream: stream}

	// Release the two acks one at a time, each after a short pause, to mimic a
	// proxy that acknowledges batches after some processing delay.
	go func() {
		for i := 0; i < 2; i++ {
			time.Sleep(10 * time.Millisecond)
			ackCh <- struct{}{}
		}
	}()

	require.NoError(t, srv.sendSnapshotSync(context.Background(), conn, stream))

	mu.Lock()
	defer mu.Unlock()

	// send, ack, send, ack, then the final send (the last batch needs no ack).
	want := []string{"send", "recv", "send", "recv", "send"}
	require.Len(t, events, len(want))
	for i, ev := range want {
		assert.Equal(t, ev, events[i])
	}
}
|
|
|
|
// orderTrackingStream logs "send" and "recv" events and blocks Recv until
|
|
// an ack is signaled via ackCh.
|
|
type orderTrackingStream struct {
|
|
grpc.ServerStream
|
|
mu *sync.Mutex
|
|
events *[]string
|
|
ackCh chan struct{}
|
|
}
|
|
|
|
func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error {
|
|
s.mu.Lock()
|
|
*s.events = append(*s.events, "send")
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
|
<-s.ackCh
|
|
s.mu.Lock()
|
|
*s.events = append(*s.events, "recv")
|
|
s.mu.Unlock()
|
|
return ackMsg(), nil
|
|
}
|
|
|
|
// Context returns a background context. The remaining grpc.ServerStream
// methods below are no-ops that always succeed.
func (s *orderTrackingStream) Context() context.Context {
	return context.Background()
}

func (s *orderTrackingStream) SetHeader(_ metadata.MD) error {
	return nil
}

func (s *orderTrackingStream) SendHeader(_ metadata.MD) error {
	return nil
}

func (s *orderTrackingStream) SetTrailer(_ metadata.MD) {
	// Trailers are ignored.
}

func (s *orderTrackingStream) SendMsg(_ any) error {
	return nil
}

func (s *orderTrackingStream) RecvMsg(_ any) error {
	return nil
}
|
|
|
|
// TestSendSnapshotSync_TokensGeneratedPerBatch verifies that auth tokens are
// generated per batch rather than for the whole snapshot up front. The single
// ack is held back for ackDelay, which is longer than the token TTL. If the
// second batch's tokens had been minted together with the first batch's, they
// would already be expired when the stream validates them on Send.
func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) {
	const cluster = "cluster.example.com"
	const batchSize = 2
	const totalServices = 4
	const ttl = 100 * time.Millisecond
	const ackDelay = 200 * time.Millisecond

	ctrl := gomock.NewController(t)
	mgr := rpservice.NewMockManager(ctrl)
	mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)

	s := newSnapshotTestServer(t, batchSize)
	s.serviceManager = mgr
	s.tokenTTL = ttl

	// Build a stream that validates tokens immediately on Send, then
	// delays the ack to ensure the next batch's tokens are generated fresh.
	var validateErrs []error
	ackCh := make(chan struct{}, 1)
	stream := &tokenValidatingSyncStream{
		tokenStore:   s.tokenStore,
		validateErrs: &validateErrs,
		ackCh:        ackCh,
	}
	conn := &proxyConnection{
		proxyID:    "proxy-a",
		address:    cluster,
		syncStream: stream,
	}

	go func() {
		// Delay ack so that if tokens were all generated upfront they'd expire.
		time.Sleep(ackDelay)
		ackCh <- struct{}{}
	}()

	err := s.sendSnapshotSync(context.Background(), conn, stream)
	require.NoError(t, err)

	// Without this check, an empty validateErrs would also be satisfied by a
	// snapshot that streamed no mappings at all, and the test would pass
	// vacuously.
	require.Equal(t, totalServices, stream.checked,
		"every service in the snapshot must be streamed and validated")
	require.Empty(t, validateErrs,
		"tokens must remain valid: per-batch generation guarantees freshness")
}

// tokenValidatingSyncStream is a SyncMappings server stream mock. On Send it
// calls ValidateAndConsume on tokenStore for every mapping's AuthToken and
// appends each failure to validateErrs instead of failing the Send. Recv
// blocks until a value arrives on ackCh and then returns an ack.
type tokenValidatingSyncStream struct {
	grpc.ServerStream

	tokenStore *OneTimeTokenStore
	// validateErrs collects one wrapped error per token that failed validation.
	validateErrs *[]error
	// ackCh gates Recv: each value received releases one Recv call.
	ackCh chan struct{}
	// checked counts every mapping passed to Send, valid or not.
	checked int
}

// Send validates and consumes each mapping's token. It never returns an error,
// so invalid tokens are reported by the test rather than by the server under
// test.
func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error {
	for _, mapping := range m.Mapping {
		s.checked++
		if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil {
			*s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err))
		}
	}
	return nil
}

// Recv blocks until the test signals on ackCh, then returns an ack message.
func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) {
	<-s.ackCh
	return ackMsg(), nil
}

// The remaining grpc.ServerStream methods are no-ops that always succeed.
func (s *tokenValidatingSyncStream) Context() context.Context  { return context.Background() }
func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error  { return nil }
func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil }
func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD)       {}
func (s *tokenValidatingSyncStream) SendMsg(any) error            { return nil }
func (s *tokenValidatingSyncStream) RecvMsg(any) error            { return nil }
|
|
|
|
// TestConnectionSendResponse_RoutesToSyncStream checks that a proxyConnection
// whose syncStream is set delivers a GetMappingUpdateResponse through that
// stream. The delivered response must keep the mapping and the
// InitialSyncComplete flag.
func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) {
	stream := new(syncRecordingStream)
	conn := &proxyConnection{syncStream: stream}

	mapping := &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"}
	update := &proto.GetMappingUpdateResponse{
		Mapping:             []*proto.ProxyMapping{mapping},
		InitialSyncComplete: true,
	}
	require.NoError(t, conn.sendResponse(update))

	require.Len(t, stream.sent, 1)
	got := stream.sent[0]
	assert.Len(t, got.Mapping, 1)
	assert.Equal(t, "svc-1", got.Mapping[0].Id)
	assert.True(t, got.InitialSyncComplete)
}
|
|
|
|
// TestConnectionSendResponse_RoutesToLegacyStream checks that a proxyConnection
// with only the legacy stream set delivers the response through that stream.
func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) {
	legacy := &recordingStream{}
	conn := &proxyConnection{stream: legacy}

	err := conn.sendResponse(&proto.GetMappingUpdateResponse{
		Mapping: []*proto.ProxyMapping{{Id: "svc-2", AccountId: "acct-2"}},
	})
	require.NoError(t, err)

	require.Len(t, legacy.messages, 1)
	assert.Equal(t, "svc-2", legacy.messages[0].Mapping[0].Id)
}
|