mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
@@ -32,6 +32,7 @@ import (
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type ProxyOIDCConfig struct {
|
||||
@@ -45,12 +46,6 @@ type ProxyOIDCConfig struct {
|
||||
KeysLocation string
|
||||
}
|
||||
|
||||
// ClusterInfo contains information about a proxy cluster.
|
||||
type ClusterInfo struct {
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -61,9 +56,9 @@ type ProxyServiceServer struct {
|
||||
// Manager for access logs
|
||||
accessLogManager accesslogs.Manager
|
||||
|
||||
mu sync.RWMutex
|
||||
// Manager for reverse proxy operations
|
||||
serviceManager rpservice.Manager
|
||||
|
||||
// ProxyController for service updates and cluster management
|
||||
proxyController proxy.Controller
|
||||
|
||||
@@ -84,23 +79,26 @@ type ProxyServiceServer struct {
|
||||
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
address string
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
proxyID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
oidcConfig: oidcConfig,
|
||||
@@ -109,6 +107,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
cancel: cancel,
|
||||
}
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
return s
|
||||
@@ -130,11 +129,22 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops background goroutines.
|
||||
func (s *ProxyServiceServer) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// SetServiceManager sets the service manager. Must be called before serving.
|
||||
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.serviceManager = manager
|
||||
}
|
||||
|
||||
// SetProxyController sets the proxy controller. Must be called before serving.
|
||||
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
@@ -157,12 +167,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: proxyAddress,
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
proxyID: proxyID,
|
||||
address: proxyAddress,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
@@ -231,29 +242,18 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only services matching the proxy's cluster address are sent.
|
||||
// Only entries matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
|
||||
var filtered []*rpservice.Service
|
||||
for _, service := range services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
if service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, service)
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
if len(mappings) == 0 {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
@@ -262,9 +262,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, service := range filtered {
|
||||
// Generate one-time authentication token for each service in the snapshot
|
||||
// Tokens are not persistent on the proxy, so we need to generate new ones on reconnection
|
||||
for i, m := range mappings {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{m},
|
||||
InitialSyncComplete: i == len(mappings)-1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
@@ -274,25 +295,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
service.ToProtoMapping(
|
||||
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
token,
|
||||
s.GetOIDCValidationConfig(),
|
||||
),
|
||||
},
|
||||
InitialSyncComplete: i == len(filtered)-1,
|
||||
}); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": service.Domain,
|
||||
"account": service.AccountID,
|
||||
}).WithError(err).Error("failed to send proxy mapping")
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
mappings = append(mappings, m)
|
||||
}
|
||||
|
||||
return nil
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// isProxyAddressValid validates a proxy address
|
||||
@@ -305,8 +311,8 @@ func isProxyAddressValid(addr string) bool {
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case msg := <-conn.sendChan:
|
||||
if err := conn.stream.Send(msg); err != nil {
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.stream.Send(resp); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -361,12 +367,12 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
msg := s.perProxyMessage(update, conn.proxyID)
|
||||
if msg == nil {
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
if resp == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
case conn.sendChan <- resp:
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
@@ -495,9 +501,40 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
Auth: m.Auth,
|
||||
PassHostHeader: m.PassHostHeader,
|
||||
RewriteRedirects: m.RewriteRedirects,
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
}
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
|
||||
// cluster reports custom port support. Returns nil if no proxy has reported
|
||||
// capabilities (old proxies that predate the field).
|
||||
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
if s.proxyController == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hasCapabilities bool
|
||||
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
|
||||
connVal, ok := s.connectedProxies.Load(pid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
|
||||
continue
|
||||
}
|
||||
if *conn.capabilities.SupportsCustomPorts {
|
||||
return ptr(true)
|
||||
}
|
||||
hasCapabilities = true
|
||||
}
|
||||
if hasCapabilities {
|
||||
return ptr(false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
@@ -585,7 +622,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
accountID := req.GetAccountId()
|
||||
serviceID := req.GetServiceId()
|
||||
@@ -604,6 +641,17 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required")
|
||||
}
|
||||
|
||||
internalStatus := protoStatusToInternal(protoStatus)
|
||||
|
||||
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
sErr, isNbErr := nbstatus.FromError(err)
|
||||
if isNbErr && sErr.Type() == nbstatus.NotFound {
|
||||
return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID)
|
||||
}
|
||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||
}
|
||||
|
||||
if certificateIssued {
|
||||
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
|
||||
@@ -615,13 +663,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}).Info("Certificate issued timestamp updated")
|
||||
}
|
||||
|
||||
internalStatus := protoStatusToInternal(protoStatus)
|
||||
|
||||
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceID,
|
||||
"account_id": accountID,
|
||||
@@ -631,7 +672,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// protoStatusToInternal maps proto status to internal status
|
||||
// protoStatusToInternal maps proto status to internal service status.
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
switch protoStatus {
|
||||
case proto.ProxyStatus_PROXY_STATUS_PENDING:
|
||||
@@ -1061,3 +1102,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
||||
|
||||
return fmt.Errorf("user not in allowed groups")
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
|
||||
Reference in New Issue
Block a user