Merge branch 'main' into proto-ipv6-overlay

This commit is contained in:
Viktor Liu
2026-05-06 12:22:28 +02:00
17 changed files with 1491 additions and 214 deletions

View File

@@ -11,6 +11,8 @@ import (
"fmt"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -82,11 +84,40 @@ type ProxyServiceServer struct {
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
// tokenTTL is the lifetime of one-time tokens generated for proxy
// authentication. Defaults to defaultProxyTokenTTL when zero.
tokenTTL time.Duration
// snapshotBatchSize is the number of mappings per gRPC message during
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
snapshotBatchSize int
cancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
const defaultProxyTokenTTL = 5 * time.Minute
const defaultSnapshotBatchSize = 500
func snapshotBatchSizeFromEnv() int {
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultSnapshotBatchSize
}
// proxyTokenTTL returns the configured token TTL or the default when unset.
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
if s.tokenTTL > 0 {
return s.tokenTTL
}
return defaultProxyTokenTTL
}
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
@@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel,
}
go s.cleanupStaleProxies(ctx)
@@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
@@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.CompareAndDelete(proxyID, conn)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
cancel()
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"session_id": sessionID,
@@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
if err := s.sendSnapshot(ctx, conn); err != nil {
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.heartbeat(connCtx, proxyRecord)
select {
@@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return err
}
// Send mappings in batches to reduce per-message gRPC overhead while
// staying well within the default 4 MB message size limit.
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
}
if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
return nil
}
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
}
return nil
@@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
if err != nil {
log.WithFields(log.Fields{
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("failed to generate auth token for snapshot")
continue
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
@@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
return true
}
select {
case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
}
return true
})
@@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
}
}
}
@@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
// Returns nil if token generation fails; the caller must disconnect the
// proxy so it can resync via a fresh snapshot on reconnect.
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping {
@@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
continue
}
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil

View File

@@ -0,0 +1,174 @@
package grpc
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
// recordingStream captures all messages sent via Send so tests can inspect
// batching behaviour without a real gRPC transport.
type recordingStream struct {
grpc.ServerStream
messages []*proto.GetMappingUpdateResponse
}
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
s.messages = append(s.messages, m)
return nil
}
func (s *recordingStream) Context() context.Context { return context.Background() }
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
func (s *recordingStream) SetTrailer(metadata.MD) {}
func (s *recordingStream) SendMsg(any) error { return nil }
func (s *recordingStream) RecvMsg(any) error { return nil }
// makeServices creates n enabled services assigned to the given cluster.
func makeServices(n int, cluster string) []*rpservice.Service {
services := make([]*rpservice.Service, n)
for i := range n {
services[i] = &rpservice.Service{
ID: fmt.Sprintf("svc-%d", i),
AccountID: "acct-1",
Name: fmt.Sprintf("svc-%d", i),
Domain: fmt.Sprintf("svc-%d.example.com", i),
ProxyCluster: cluster,
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
},
}
}
return services
}
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
t.Helper()
s := &ProxyServiceServer{
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
snapshotBatchSize: batchSize,
}
s.SetProxyController(newTestProxyController())
return s
}
func TestSendSnapshot_BatchesMappings(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 7 // 3 + 3 + 1
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
stream: stream,
}
err := s.sendSnapshot(context.Background(), conn)
require.NoError(t, err)
// Expect ceil(7/3) = 3 messages
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
assert.Len(t, stream.messages[1].Mapping, 3)
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
assert.Len(t, stream.messages[2].Mapping, 1)
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
// Verify all service IDs are present exactly once
seen := make(map[string]bool)
for _, msg := range stream.messages {
for _, m := range msg.Mapping {
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
seen[m.Id] = true
}
}
assert.Len(t, seen, totalServices)
}
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 6 // exactly 2 batches
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 2)
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete)
assert.Len(t, stream.messages[1].Mapping, 3)
assert.True(t, stream.messages[1].InitialSyncComplete)
}
func TestSendSnapshot_SingleBatch(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 100
const totalServices = 5
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
assert.Len(t, stream.messages[0].Mapping, totalServices)
assert.True(t, stream.messages[0].InitialSyncComplete)
}
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
const cluster = "cluster.example.com"
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
s := newSnapshotTestServer(t, 500)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
assert.Empty(t, stream.messages[0].Mapping)
assert.True(t, stream.messages[0].InitialSyncComplete)
}

View File

@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10)
ctx, cancel := context.WithCancel(context.Background())
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
capabilities: caps,
sendChan: ch,
ctx: ctx,
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)

View File

@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Modify one character in the secret part
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
replacement := "X"
if plainToken[5] == 'X' {
replacement = "Y"
}
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
err = ValidateInviteToken(modifiedToken)
require.Error(t, err)
}