[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -32,6 +32,7 @@ import (
proxyauth "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
nbstatus "github.com/netbirdio/netbird/shared/management/status"
)
type ProxyOIDCConfig struct {
@@ -45,12 +46,6 @@ type ProxyOIDCConfig struct {
KeysLocation string
}
// ClusterInfo contains information about a proxy cluster.
type ClusterInfo struct {
Address string
ConnectedProxies int
}
// ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer
@@ -61,9 +56,9 @@ type ProxyServiceServer struct {
// Manager for access logs
accessLogManager accesslogs.Manager
mu sync.RWMutex
// Manager for reverse proxy operations
serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController proxy.Controller
@@ -84,23 +79,26 @@ type ProxyServiceServer struct {
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
cancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
address string
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
proxyID string
address string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
@@ -109,6 +107,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
cancel: cancel,
}
go s.cleanupStaleProxies(ctx)
return s
@@ -130,11 +129,22 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
}
}
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.cancel()
}
// SetServiceManager sets the service manager. Must be called before serving.
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.mu.Lock()
defer s.mu.Unlock()
s.serviceManager = manager
}
// SetProxyController sets the proxy controller. Must be called before serving.
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
s.mu.Lock()
defer s.mu.Unlock()
s.proxyController = proxyController
}
@@ -157,12 +167,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
address: proxyAddress,
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
proxyID: proxyID,
address: proxyAddress,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
@@ -231,29 +242,18 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
// Only entries matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return fmt.Errorf("get services from store: %w", err)
}
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
var filtered []*rpservice.Service
for _, service := range services {
if !service.Enabled {
continue
}
if service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue
}
filtered = append(filtered, service)
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
return err
}
if len(filtered) == 0 {
if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true,
}); err != nil {
@@ -262,9 +262,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return nil
}
for i, service := range filtered {
// Generate one-time authentication token for each service in the snapshot
// Tokens are not persistent on the proxy, so we need to generate new ones on reconnection
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
}
return nil
}
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services from store: %w", err)
}
var mappings []*proto.ProxyMapping
for _, service := range services {
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
if err != nil {
log.WithFields(log.Fields{
@@ -274,25 +295,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
continue
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
service.ToProtoMapping(
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token,
s.GetOIDCValidationConfig(),
),
},
InitialSyncComplete: i == len(filtered)-1,
}); err != nil {
log.WithFields(log.Fields{
"domain": service.Domain,
"account": service.AccountID,
}).WithError(err).Error("failed to send proxy mapping")
return fmt.Errorf("send proxy mapping: %w", err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
mappings = append(mappings, m)
}
return nil
return mappings, nil
}
// isProxyAddressValid validates a proxy address
@@ -305,8 +311,8 @@ func isProxyAddressValid(addr string) bool {
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
for {
select {
case msg := <-conn.sendChan:
if err := conn.stream.Send(msg); err != nil {
case resp := <-conn.sendChan:
if err := conn.stream.Send(resp); err != nil {
errChan <- err
return
}
@@ -361,12 +367,12 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
msg := s.perProxyMessage(update, conn.proxyID)
if msg == nil {
resp := s.perProxyMessage(update, conn.proxyID)
if resp == nil {
return true
}
select {
case conn.sendChan <- msg:
case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
@@ -495,9 +501,40 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
Auth: m.Auth,
PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects,
Mode: m.Mode,
ListenPort: m.ListenPort,
}
}
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
// cluster reports custom port support. Returns nil if no proxy has reported
// capabilities (old proxies that predate the field).
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
if s.proxyController == nil {
return nil
}
var hasCapabilities bool
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
connVal, ok := s.connectedProxies.Load(pid)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
continue
}
if *conn.capabilities.SupportsCustomPorts {
return ptr(true)
}
hasCapabilities = true
}
if hasCapabilities {
return ptr(false)
}
return nil
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
@@ -585,7 +622,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
return token, nil
}
// SendStatusUpdate handles status updates from proxy clients
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
accountID := req.GetAccountId()
serviceID := req.GetServiceId()
@@ -604,6 +641,17 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required")
}
internalStatus := protoStatusToInternal(protoStatus)
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
sErr, isNbErr := nbstatus.FromError(err)
if isNbErr && sErr.Type() == nbstatus.NotFound {
return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID)
}
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
if certificateIssued {
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
@@ -615,13 +663,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}).Info("Certificate issued timestamp updated")
}
internalStatus := protoStatusToInternal(protoStatus)
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
log.WithFields(log.Fields{
"service_id": serviceID,
"account_id": accountID,
@@ -631,7 +672,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
return &proto.SendStatusUpdateResponse{}, nil
}
// protoStatusToInternal maps proto status to internal status
// protoStatusToInternal maps proto status to internal service status.
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING:
@@ -1061,3 +1102,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
return fmt.Errorf("user not in allowed groups")
}
func ptr[T any](v T) *T { return &v }