mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
@@ -2,6 +2,7 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -39,23 +40,38 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage)
|
||||
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||
}
|
||||
|
||||
if exposeReq.Port > 65535 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "port out of range: %d", exposeReq.Port)
|
||||
}
|
||||
if exposeReq.ListenPort > 65535 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "listen_port out of range: %d", exposeReq.ListenPort)
|
||||
}
|
||||
|
||||
mode, err := exposeProtocolToString(exposeReq.Protocol)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
|
||||
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
|
||||
NamePrefix: exposeReq.NamePrefix,
|
||||
Port: int(exposeReq.Port),
|
||||
Protocol: exposeProtocolToString(exposeReq.Protocol),
|
||||
Domain: exposeReq.Domain,
|
||||
Pin: exposeReq.Pin,
|
||||
Password: exposeReq.Password,
|
||||
UserGroups: exposeReq.UserGroups,
|
||||
NamePrefix: exposeReq.NamePrefix,
|
||||
Port: uint16(exposeReq.Port), //nolint:gosec // validated above
|
||||
Mode: mode,
|
||||
TargetProtocol: exposeTargetProtocol(exposeReq.Protocol),
|
||||
Domain: exposeReq.Domain,
|
||||
Pin: exposeReq.Pin,
|
||||
Password: exposeReq.Password,
|
||||
UserGroups: exposeReq.UserGroups,
|
||||
ListenPort: uint16(exposeReq.ListenPort), //nolint:gosec // validated above
|
||||
})
|
||||
if err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
|
||||
ServiceName: created.ServiceName,
|
||||
ServiceUrl: created.ServiceURL,
|
||||
Domain: created.Domain,
|
||||
ServiceName: created.ServiceName,
|
||||
ServiceUrl: created.ServiceURL,
|
||||
Domain: created.Domain,
|
||||
PortAutoAssigned: created.PortAutoAssigned,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -77,7 +93,12 @@ func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (
|
||||
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||
}
|
||||
|
||||
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
|
||||
serviceID, err := s.resolveServiceID(ctx, renewReq.Domain)
|
||||
if err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -102,7 +123,12 @@ func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*
|
||||
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||
}
|
||||
|
||||
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
|
||||
serviceID, err := s.resolveServiceID(ctx, stopReq.Domain)
|
||||
if err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil {
|
||||
return nil, mapExposeError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -180,13 +206,46 @@ func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
|
||||
s.reverseProxyManager = mgr
|
||||
}
|
||||
|
||||
func exposeProtocolToString(p proto.ExposeProtocol) string {
|
||||
// resolveServiceID looks up the service by its globally unique domain.
|
||||
func (s *Server) resolveServiceID(ctx context.Context, domain string) (string, error) {
|
||||
if domain == "" {
|
||||
return "", status.Errorf(codes.InvalidArgument, "domain is required")
|
||||
}
|
||||
|
||||
svc, err := s.accountManager.GetStore().GetServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return svc.ID, nil
|
||||
}
|
||||
|
||||
func exposeProtocolToString(p proto.ExposeProtocol) (string, error) {
|
||||
switch p {
|
||||
case proto.ExposeProtocol_EXPOSE_HTTP:
|
||||
return "http"
|
||||
case proto.ExposeProtocol_EXPOSE_HTTPS:
|
||||
return "https"
|
||||
case proto.ExposeProtocol_EXPOSE_HTTP, proto.ExposeProtocol_EXPOSE_HTTPS:
|
||||
return "http", nil
|
||||
case proto.ExposeProtocol_EXPOSE_TCP:
|
||||
return "tcp", nil
|
||||
case proto.ExposeProtocol_EXPOSE_UDP:
|
||||
return "udp", nil
|
||||
case proto.ExposeProtocol_EXPOSE_TLS:
|
||||
return "tls", nil
|
||||
default:
|
||||
return "http"
|
||||
return "", fmt.Errorf("unsupported expose protocol: %v", p)
|
||||
}
|
||||
}
|
||||
|
||||
// exposeTargetProtocol returns the target protocol for the given expose protocol.
|
||||
// For HTTP mode, this is http or https (the scheme used to connect to the backend).
|
||||
// For L4 modes, this is tcp or udp (the transport used to connect to the backend).
|
||||
func exposeTargetProtocol(p proto.ExposeProtocol) string {
|
||||
switch p {
|
||||
case proto.ExposeProtocol_EXPOSE_HTTPS:
|
||||
return rpservice.TargetProtoHTTPS
|
||||
case proto.ExposeProtocol_EXPOSE_TCP, proto.ExposeProtocol_EXPOSE_TLS:
|
||||
return rpservice.TargetProtoTCP
|
||||
case proto.ExposeProtocol_EXPOSE_UDP:
|
||||
return rpservice.TargetProtoUDP
|
||||
default:
|
||||
return rpservice.TargetProtoHTTP
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type ProxyOIDCConfig struct {
|
||||
@@ -45,12 +46,6 @@ type ProxyOIDCConfig struct {
|
||||
KeysLocation string
|
||||
}
|
||||
|
||||
// ClusterInfo contains information about a proxy cluster.
|
||||
type ClusterInfo struct {
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -61,9 +56,9 @@ type ProxyServiceServer struct {
|
||||
// Manager for access logs
|
||||
accessLogManager accesslogs.Manager
|
||||
|
||||
mu sync.RWMutex
|
||||
// Manager for reverse proxy operations
|
||||
serviceManager rpservice.Manager
|
||||
|
||||
// ProxyController for service updates and cluster management
|
||||
proxyController proxy.Controller
|
||||
|
||||
@@ -84,23 +79,26 @@ type ProxyServiceServer struct {
|
||||
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
address string
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
proxyID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
oidcConfig: oidcConfig,
|
||||
@@ -109,6 +107,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
cancel: cancel,
|
||||
}
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
return s
|
||||
@@ -130,11 +129,22 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops background goroutines.
|
||||
func (s *ProxyServiceServer) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// SetServiceManager sets the service manager. Must be called before serving.
|
||||
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.serviceManager = manager
|
||||
}
|
||||
|
||||
// SetProxyController sets the proxy controller. Must be called before serving.
|
||||
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
@@ -157,12 +167,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: proxyAddress,
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
proxyID: proxyID,
|
||||
address: proxyAddress,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
@@ -231,29 +242,18 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only services matching the proxy's cluster address are sent.
|
||||
// Only entries matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
|
||||
var filtered []*rpservice.Service
|
||||
for _, service := range services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
if service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, service)
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
if len(mappings) == 0 {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
@@ -262,9 +262,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, service := range filtered {
|
||||
// Generate one-time authentication token for each service in the snapshot
|
||||
// Tokens are not persistent on the proxy, so we need to generate new ones on reconnection
|
||||
for i, m := range mappings {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{m},
|
||||
InitialSyncComplete: i == len(mappings)-1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
@@ -274,25 +295,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
service.ToProtoMapping(
|
||||
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
token,
|
||||
s.GetOIDCValidationConfig(),
|
||||
),
|
||||
},
|
||||
InitialSyncComplete: i == len(filtered)-1,
|
||||
}); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": service.Domain,
|
||||
"account": service.AccountID,
|
||||
}).WithError(err).Error("failed to send proxy mapping")
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
mappings = append(mappings, m)
|
||||
}
|
||||
|
||||
return nil
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// isProxyAddressValid validates a proxy address
|
||||
@@ -305,8 +311,8 @@ func isProxyAddressValid(addr string) bool {
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case msg := <-conn.sendChan:
|
||||
if err := conn.stream.Send(msg); err != nil {
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.stream.Send(resp); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -361,12 +367,12 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
msg := s.perProxyMessage(update, conn.proxyID)
|
||||
if msg == nil {
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
if resp == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
case conn.sendChan <- resp:
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
@@ -495,9 +501,40 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
Auth: m.Auth,
|
||||
PassHostHeader: m.PassHostHeader,
|
||||
RewriteRedirects: m.RewriteRedirects,
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
}
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
|
||||
// cluster reports custom port support. Returns nil if no proxy has reported
|
||||
// capabilities (old proxies that predate the field).
|
||||
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
if s.proxyController == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hasCapabilities bool
|
||||
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
|
||||
connVal, ok := s.connectedProxies.Load(pid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
|
||||
continue
|
||||
}
|
||||
if *conn.capabilities.SupportsCustomPorts {
|
||||
return ptr(true)
|
||||
}
|
||||
hasCapabilities = true
|
||||
}
|
||||
if hasCapabilities {
|
||||
return ptr(false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
@@ -585,7 +622,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
accountID := req.GetAccountId()
|
||||
serviceID := req.GetServiceId()
|
||||
@@ -604,6 +641,17 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required")
|
||||
}
|
||||
|
||||
internalStatus := protoStatusToInternal(protoStatus)
|
||||
|
||||
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
sErr, isNbErr := nbstatus.FromError(err)
|
||||
if isNbErr && sErr.Type() == nbstatus.NotFound {
|
||||
return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID)
|
||||
}
|
||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||
}
|
||||
|
||||
if certificateIssued {
|
||||
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
|
||||
@@ -615,13 +663,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}).Info("Certificate issued timestamp updated")
|
||||
}
|
||||
|
||||
internalStatus := protoStatusToInternal(protoStatus)
|
||||
|
||||
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": serviceID,
|
||||
"account_id": accountID,
|
||||
@@ -631,7 +672,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// protoStatusToInternal maps proto status to internal status
|
||||
// protoStatusToInternal maps proto status to internal service status.
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
switch protoStatus {
|
||||
case proto.ProxyStatus_PROXY_STATUS_PENDING:
|
||||
@@ -1061,3 +1102,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
||||
|
||||
return fmt.Errorf("user not in allowed groups")
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
|
||||
@@ -53,6 +53,10 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
|
||||
return ptr(true)
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -70,11 +74,17 @@ func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string
|
||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||
// and returns the channel where messages will be received.
|
||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
|
||||
return registerFakeProxyWithCaps(s, proxyID, clusterAddr, nil)
|
||||
}
|
||||
|
||||
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
|
||||
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
|
||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
sendChan: ch,
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
capabilities: caps,
|
||||
sendChan: ch,
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
@@ -83,15 +93,29 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
return ch
|
||||
}
|
||||
|
||||
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||
// drainMapping drains a single ProxyMapping from the channel.
|
||||
func drainMapping(ch chan *proto.GetMappingUpdateResponse) *proto.ProxyMapping {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
return msg
|
||||
case resp := <-ch:
|
||||
if len(resp.Mapping) > 0 {
|
||||
return resp.Mapping[0]
|
||||
}
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// drainEmpty checks if a channel has no message within timeout.
|
||||
func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool {
|
||||
select {
|
||||
case <-ch:
|
||||
return false
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||
@@ -129,10 +153,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
|
||||
tokens := make([]string, numProxies)
|
||||
for i, ch := range channels {
|
||||
resp := drainChannel(ch)
|
||||
require.NotNil(t, resp, "proxy %d should receive a message", i)
|
||||
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
|
||||
msg := resp.Mapping[0]
|
||||
msg := drainMapping(ch)
|
||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
||||
assert.Equal(t, mapping.Domain, msg.Domain)
|
||||
assert.Equal(t, mapping.Id, msg.Id)
|
||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||
@@ -181,16 +203,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
||||
|
||||
resp1 := drainChannel(ch1)
|
||||
resp2 := drainChannel(ch2)
|
||||
require.NotNil(t, resp1)
|
||||
require.NotNil(t, resp2)
|
||||
require.Len(t, resp1.Mapping, 1)
|
||||
require.Len(t, resp2.Mapping, 1)
|
||||
msg1 := drainMapping(ch1)
|
||||
msg2 := drainMapping(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
|
||||
// Delete operations should not generate tokens
|
||||
assert.Empty(t, resp1.Mapping[0].AuthToken)
|
||||
assert.Empty(t, resp2.Mapping[0].AuthToken)
|
||||
assert.Empty(t, msg1.AuthToken)
|
||||
assert.Empty(t, msg2.AuthToken)
|
||||
}
|
||||
|
||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
@@ -224,15 +244,10 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdate(update)
|
||||
|
||||
resp1 := drainChannel(ch1)
|
||||
resp2 := drainChannel(ch2)
|
||||
require.NotNil(t, resp1)
|
||||
require.NotNil(t, resp2)
|
||||
require.Len(t, resp1.Mapping, 1)
|
||||
require.Len(t, resp2.Mapping, 1)
|
||||
|
||||
msg1 := resp1.Mapping[0]
|
||||
msg2 := resp2.Mapping[0]
|
||||
msg1 := drainMapping(ch1)
|
||||
msg2 := drainMapping(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
|
||||
assert.NotEmpty(t, msg1.AuthToken)
|
||||
assert.NotEmpty(t, msg2.AuthToken)
|
||||
@@ -324,3 +339,314 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state signature")
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
|
||||
// Proxy A supports custom ports.
|
||||
chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
// Proxy B does NOT support custom ports (shared cloud proxy).
|
||||
chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// TLS passthrough works on all proxies regardless of custom port support.
|
||||
tlsMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-tls",
|
||||
AccountId: "account-1",
|
||||
Domain: "db.example.com",
|
||||
Mode: "tls",
|
||||
ListenPort: 8443,
|
||||
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster)
|
||||
|
||||
msgA := drainMapping(chA)
|
||||
msgB := drainMapping(chB)
|
||||
assert.NotNil(t, msgA, "proxy-a should receive TLS mapping")
|
||||
assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)")
|
||||
|
||||
// Send an HTTP mapping: both should receive it.
|
||||
httpMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-http",
|
||||
AccountId: "account-1",
|
||||
Domain: "app.example.com",
|
||||
Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:80"}},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping, cluster)
|
||||
|
||||
msgA = drainMapping(chA)
|
||||
msgB = drainMapping(chB)
|
||||
assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping")
|
||||
assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping")
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
|
||||
chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
|
||||
tlsMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-tls",
|
||||
AccountId: "account-1",
|
||||
Domain: "db.example.com",
|
||||
Mode: "tls",
|
||||
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster)
|
||||
|
||||
msg := drainMapping(chShared)
|
||||
assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support")
|
||||
}
|
||||
|
||||
// TestServiceModifyNotifications exercises every possible modification
|
||||
// scenario for an existing service, verifying the correct update types
|
||||
// reach the correct clusters.
|
||||
func TestServiceModifyNotifications(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) {
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
chs := map[string]chan *proto.GetMappingUpdateResponse{
|
||||
"cluster-a": registerFakeProxyWithCaps(s, "proxy-a", "cluster-a", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}),
|
||||
"cluster-b": registerFakeProxyWithCaps(s, "proxy-b", "cluster-b", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}),
|
||||
}
|
||||
return s, chs
|
||||
}
|
||||
|
||||
httpMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping {
|
||||
return &proto.ProxyMapping{
|
||||
Type: updateType,
|
||||
Id: "svc-1",
|
||||
AccountId: "acct-1",
|
||||
Domain: "app.example.com",
|
||||
Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:8080"}},
|
||||
}
|
||||
}
|
||||
|
||||
tlsOnlyMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping {
|
||||
return &proto.ProxyMapping{
|
||||
Type: updateType,
|
||||
Id: "svc-1",
|
||||
AccountId: "acct-1",
|
||||
Domain: "app.example.com",
|
||||
Mode: "tls",
|
||||
ListenPort: 8443,
|
||||
Path: []*proto.PathMapping{{Target: "10.0.0.1:443"}},
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("targets changed sends MODIFIED to same cluster", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg, "cluster-a should receive update")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
|
||||
assert.NotEmpty(t, msg.AuthToken, "MODIFIED should include token")
|
||||
assert.True(t, drainEmpty(chs["cluster-b"]), "cluster-b should not receive update")
|
||||
})
|
||||
|
||||
t.Run("auth config changed sends MODIFIED", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
|
||||
mapping.Auth = &proto.Authentication{Password: true, Pin: true}
|
||||
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
|
||||
assert.True(t, msg.Auth.Password)
|
||||
assert.True(t, msg.Auth.Pin)
|
||||
})
|
||||
|
||||
t.Run("HTTP to TLS transition sends MODIFIED with TLS config", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
|
||||
assert.Equal(t, "tls", msg.Mode, "mode should be tls")
|
||||
assert.Equal(t, int32(8443), msg.ListenPort)
|
||||
assert.Len(t, msg.Path, 1, "should have one path entry with target address")
|
||||
assert.Equal(t, "10.0.0.1:443", msg.Path[0].Target)
|
||||
})
|
||||
|
||||
t.Run("TLS to HTTP transition sends MODIFIED without TLS", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
|
||||
assert.Empty(t, msg.Mode, "mode should be empty for HTTP")
|
||||
assert.True(t, len(msg.Path) > 0)
|
||||
})
|
||||
|
||||
t.Run("TLS port changed sends MODIFIED with new port", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
|
||||
mapping.ListenPort = 9443
|
||||
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
assert.Equal(t, int32(9443), msg.ListenPort)
|
||||
})
|
||||
|
||||
t.Run("disable sends REMOVED to cluster", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
// Manager sends Delete when service is disabled
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.Type)
|
||||
assert.Empty(t, msg.AuthToken, "DELETE should not have token")
|
||||
})
|
||||
|
||||
t.Run("enable sends CREATED to cluster", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msg.Type)
|
||||
assert.NotEmpty(t, msg.AuthToken)
|
||||
})
|
||||
|
||||
t.Run("domain change with cluster change sends DELETE to old CREATE to new", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
// This is the pattern the manager produces:
|
||||
// 1. DELETE on old cluster
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a")
|
||||
// 2. CREATE on new cluster
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-b")
|
||||
|
||||
msgA := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msgA, "old cluster should receive DELETE")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgA.Type)
|
||||
|
||||
msgB := drainMapping(chs["cluster-b"])
|
||||
require.NotNil(t, msgB, "new cluster should receive CREATE")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgB.Type)
|
||||
assert.NotEmpty(t, msgB.AuthToken)
|
||||
})
|
||||
|
||||
t.Run("domain change same cluster sends DELETE then CREATE", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
// Domain changes within same cluster: manager sends DELETE (old domain) + CREATE (new domain).
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a")
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a")
|
||||
|
||||
msgDel := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msgDel, "same cluster should receive DELETE")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgDel.Type)
|
||||
|
||||
msgCreate := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msgCreate, "same cluster should receive CREATE")
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgCreate.Type)
|
||||
assert.NotEmpty(t, msgCreate.AuthToken)
|
||||
})
|
||||
|
||||
t.Run("TLS passthrough sent to all proxies", func(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
const cluster = "proxy.example.com"
|
||||
chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
|
||||
// TLS passthrough works on all proxies regardless of custom port support
|
||||
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster)
|
||||
|
||||
msgModern := drainMapping(chModern)
|
||||
require.NotNil(t, msgModern, "modern proxy receives TLS update")
|
||||
assert.Equal(t, "tls", msgModern.Mode)
|
||||
|
||||
msgLegacy := drainMapping(chLegacy)
|
||||
assert.NotNil(t, msgLegacy, "legacy proxy should also receive TLS passthrough")
|
||||
})
|
||||
|
||||
t.Run("TLS on default port NOT filtered for legacy proxy", func(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
const cluster = "proxy.example.com"
|
||||
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
|
||||
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
|
||||
mapping.ListenPort = 0 // default port
|
||||
s.SendServiceUpdateToCluster(ctx, mapping, cluster)
|
||||
|
||||
msgLegacy := drainMapping(chLegacy)
|
||||
assert.NotNil(t, msgLegacy, "legacy proxy should receive TLS on default port")
|
||||
})
|
||||
|
||||
t.Run("passthrough and rewrite flags propagated", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
|
||||
mapping.PassHostHeader = true
|
||||
mapping.RewriteRedirects = true
|
||||
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
assert.True(t, msg.PassHostHeader)
|
||||
assert.True(t, msg.RewriteRedirects)
|
||||
})
|
||||
|
||||
t.Run("multiple paths propagated in MODIFIED", func(t *testing.T) {
|
||||
s, chs := newServer()
|
||||
mapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED,
|
||||
Id: "svc-multi",
|
||||
AccountId: "acct-1",
|
||||
Domain: "multi.example.com",
|
||||
Path: []*proto.PathMapping{
|
||||
{Path: "/", Target: "http://10.0.0.1:8080"},
|
||||
{Path: "/api", Target: "http://10.0.0.2:9090"},
|
||||
{Path: "/ws", Target: "http://10.0.0.3:3000"},
|
||||
},
|
||||
}
|
||||
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
|
||||
|
||||
msg := drainMapping(chs["cluster-a"])
|
||||
require.NotNil(t, msg)
|
||||
require.Len(t, msg.Path, 3, "all paths should be present")
|
||||
assert.Equal(t, "/", msg.Path[0].Path)
|
||||
assert.Equal(t, "/api", msg.Path[1].Path)
|
||||
assert.Equal(t, "/ws", msg.Path[2].Path)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user