[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -2,6 +2,7 @@ package grpc
import (
"context"
"fmt"
pb "github.com/golang/protobuf/proto" // nolint
log "github.com/sirupsen/logrus"
@@ -39,23 +40,38 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage)
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if exposeReq.Port > 65535 {
return nil, status.Errorf(codes.InvalidArgument, "port out of range: %d", exposeReq.Port)
}
if exposeReq.ListenPort > 65535 {
return nil, status.Errorf(codes.InvalidArgument, "listen_port out of range: %d", exposeReq.ListenPort)
}
mode, err := exposeProtocolToString(exposeReq.Protocol)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
NamePrefix: exposeReq.NamePrefix,
Port: int(exposeReq.Port),
Protocol: exposeProtocolToString(exposeReq.Protocol),
Domain: exposeReq.Domain,
Pin: exposeReq.Pin,
Password: exposeReq.Password,
UserGroups: exposeReq.UserGroups,
NamePrefix: exposeReq.NamePrefix,
Port: uint16(exposeReq.Port), //nolint:gosec // validated above
Mode: mode,
TargetProtocol: exposeTargetProtocol(exposeReq.Protocol),
Domain: exposeReq.Domain,
Pin: exposeReq.Pin,
Password: exposeReq.Password,
UserGroups: exposeReq.UserGroups,
ListenPort: uint16(exposeReq.ListenPort), //nolint:gosec // validated above
})
if err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
ServiceName: created.ServiceName,
ServiceUrl: created.ServiceURL,
Domain: created.Domain,
ServiceName: created.ServiceName,
ServiceUrl: created.ServiceURL,
Domain: created.Domain,
PortAutoAssigned: created.PortAutoAssigned,
})
}
@@ -77,7 +93,12 @@ func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
serviceID, err := s.resolveServiceID(ctx, renewReq.Domain)
if err != nil {
return nil, mapExposeError(ctx, err)
}
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil {
return nil, mapExposeError(ctx, err)
}
@@ -102,7 +123,12 @@ func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
serviceID, err := s.resolveServiceID(ctx, stopReq.Domain)
if err != nil {
return nil, mapExposeError(ctx, err)
}
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, serviceID); err != nil {
return nil, mapExposeError(ctx, err)
}
@@ -180,13 +206,46 @@ func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
s.reverseProxyManager = mgr
}
func exposeProtocolToString(p proto.ExposeProtocol) string {
// resolveServiceID resolves a service ID from its globally unique domain.
// An empty domain is rejected with InvalidArgument; store lookup errors are
// returned unchanged so the caller can map them to gRPC status codes.
func (s *Server) resolveServiceID(ctx context.Context, domain string) (string, error) {
	if domain == "" {
		return "", status.Errorf(codes.InvalidArgument, "domain is required")
	}
	service, err := s.accountManager.GetStore().GetServiceByDomain(ctx, domain)
	if err != nil {
		return "", err
	}
	return service.ID, nil
}
// exposeProtocolToString maps a proto expose protocol to the internal proxy
// mode string. HTTP and HTTPS both run in "http" mode (the scheme only
// affects the target protocol); TCP, UDP and TLS get dedicated L4 modes.
// Unknown enum values are rejected so newly added protocols fail loudly
// instead of silently falling back to HTTP.
func exposeProtocolToString(p proto.ExposeProtocol) (string, error) {
	switch p {
	case proto.ExposeProtocol_EXPOSE_HTTP, proto.ExposeProtocol_EXPOSE_HTTPS:
		return "http", nil
	case proto.ExposeProtocol_EXPOSE_TCP:
		return "tcp", nil
	case proto.ExposeProtocol_EXPOSE_UDP:
		return "udp", nil
	case proto.ExposeProtocol_EXPOSE_TLS:
		return "tls", nil
	default:
		return "", fmt.Errorf("unsupported expose protocol: %v", p)
	}
}
// exposeTargetProtocol returns the target protocol for the given expose protocol.
// For HTTP mode, this is http or https (the scheme used to connect to the backend).
// For L4 modes, this is tcp or udp (the transport used to connect to the backend).
func exposeTargetProtocol(p proto.ExposeProtocol) string {
	if p == proto.ExposeProtocol_EXPOSE_UDP {
		return rpservice.TargetProtoUDP
	}
	if p == proto.ExposeProtocol_EXPOSE_HTTPS {
		return rpservice.TargetProtoHTTPS
	}
	if p == proto.ExposeProtocol_EXPOSE_TCP || p == proto.ExposeProtocol_EXPOSE_TLS {
		return rpservice.TargetProtoTCP
	}
	// EXPOSE_HTTP and any unknown value default to the plain HTTP target.
	return rpservice.TargetProtoHTTP
}

View File

@@ -32,6 +32,7 @@ import (
proxyauth "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
nbstatus "github.com/netbirdio/netbird/shared/management/status"
)
type ProxyOIDCConfig struct {
@@ -45,12 +46,6 @@ type ProxyOIDCConfig struct {
KeysLocation string
}
// ClusterInfo contains information about a proxy cluster.
type ClusterInfo struct {
Address string
ConnectedProxies int
}
// ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer
@@ -61,9 +56,9 @@ type ProxyServiceServer struct {
// Manager for access logs
accessLogManager accesslogs.Manager
mu sync.RWMutex
// Manager for reverse proxy operations
serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController proxy.Controller
@@ -84,23 +79,26 @@ type ProxyServiceServer struct {
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
cancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
address string
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
proxyID string
address string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
@@ -109,6 +107,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
cancel: cancel,
}
go s.cleanupStaleProxies(ctx)
return s
@@ -130,11 +129,22 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
}
}
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.cancel()
}
// SetServiceManager sets the service manager. Must be called before serving.
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.mu.Lock()
defer s.mu.Unlock()
s.serviceManager = manager
}
// SetProxyController sets the proxy controller. Must be called before serving.
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
s.mu.Lock()
defer s.mu.Unlock()
s.proxyController = proxyController
}
@@ -157,12 +167,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
address: proxyAddress,
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
proxyID: proxyID,
address: proxyAddress,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
@@ -231,29 +242,18 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
// Only entries matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return fmt.Errorf("get services from store: %w", err)
}
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
var filtered []*rpservice.Service
for _, service := range services {
if !service.Enabled {
continue
}
if service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue
}
filtered = append(filtered, service)
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
return err
}
if len(filtered) == 0 {
if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true,
}); err != nil {
@@ -262,9 +262,30 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return nil
}
for i, service := range filtered {
// Generate one-time authentication token for each service in the snapshot
// Tokens are not persistent on the proxy, so we need to generate new ones on reconnection
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
}
return nil
}
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services from store: %w", err)
}
var mappings []*proto.ProxyMapping
for _, service := range services {
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
if err != nil {
log.WithFields(log.Fields{
@@ -274,25 +295,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
continue
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
service.ToProtoMapping(
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token,
s.GetOIDCValidationConfig(),
),
},
InitialSyncComplete: i == len(filtered)-1,
}); err != nil {
log.WithFields(log.Fields{
"domain": service.Domain,
"account": service.AccountID,
}).WithError(err).Error("failed to send proxy mapping")
return fmt.Errorf("send proxy mapping: %w", err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
mappings = append(mappings, m)
}
return nil
return mappings, nil
}
// isProxyAddressValid validates a proxy address
@@ -305,8 +311,8 @@ func isProxyAddressValid(addr string) bool {
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
for {
select {
case msg := <-conn.sendChan:
if err := conn.stream.Send(msg); err != nil {
case resp := <-conn.sendChan:
if err := conn.stream.Send(resp); err != nil {
errChan <- err
return
}
@@ -361,12 +367,12 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
msg := s.perProxyMessage(update, conn.proxyID)
if msg == nil {
resp := s.perProxyMessage(update, conn.proxyID)
if resp == nil {
return true
}
select {
case conn.sendChan <- msg:
case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
@@ -495,9 +501,40 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
Auth: m.Auth,
PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects,
Mode: m.Mode,
ListenPort: m.ListenPort,
}
}
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
// cluster reports custom port support. Returns nil if no proxy has reported
// capabilities (old proxies that predate the field).
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
// Without a controller we cannot enumerate the cluster's proxies.
if s.proxyController == nil {
return nil
}
// Tracks whether at least one proxy reported the capability field at all.
var hasCapabilities bool
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
connVal, ok := s.connectedProxies.Load(pid)
if !ok {
// Proxy is known to the controller but has no live connection here.
continue
}
conn := connVal.(*proxyConnection)
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
// Older proxy that never sent capabilities; does not count either way.
continue
}
// A single capable proxy is enough for the cluster to support custom ports.
if *conn.capabilities.SupportsCustomPorts {
return ptr(true)
}
hasCapabilities = true
}
// Every proxy that reported capabilities said "false".
if hasCapabilities {
return ptr(false)
}
// No connected proxy reported capabilities at all.
return nil
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
@@ -585,7 +622,7 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
return token, nil
}
// SendStatusUpdate handles status updates from proxy clients
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
accountID := req.GetAccountId()
serviceID := req.GetServiceId()
@@ -604,6 +641,17 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required")
}
internalStatus := protoStatusToInternal(protoStatus)
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
sErr, isNbErr := nbstatus.FromError(err)
if isNbErr && sErr.Type() == nbstatus.NotFound {
return nil, status.Errorf(codes.NotFound, "service %s not found", serviceID)
}
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
if certificateIssued {
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
@@ -615,13 +663,6 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}).Info("Certificate issued timestamp updated")
}
internalStatus := protoStatusToInternal(protoStatus)
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
log.WithFields(log.Fields{
"service_id": serviceID,
"account_id": accountID,
@@ -631,7 +672,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
return &proto.SendStatusUpdateResponse{}, nil
}
// protoStatusToInternal maps proto status to internal status
// protoStatusToInternal maps proto status to internal service status.
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING:
@@ -1061,3 +1102,5 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
return fmt.Errorf("user not in allowed groups")
}
// ptr returns a pointer to a freshly allocated copy of v.
func ptr[T any](v T) *T {
	out := new(T)
	*out = v
	return out
}

View File

@@ -53,6 +53,10 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus
return nil
}
// ClusterSupportsCustomPorts always reports custom-port support in tests.
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
return ptr(true)
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
@@ -70,11 +74,17 @@ func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
// nil capabilities models an old proxy that predates the capabilities field.
return registerFakeProxyWithCaps(s, proxyID, clusterAddr, nil)
}
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
sendChan: ch,
proxyID: proxyID,
address: clusterAddr,
capabilities: caps,
sendChan: ch,
}
s.connectedProxies.Store(proxyID, conn)
@@ -83,15 +93,29 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
return ch
}
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
// drainMapping drains a single ProxyMapping from the channel.
func drainMapping(ch chan *proto.GetMappingUpdateResponse) *proto.ProxyMapping {
select {
case msg := <-ch:
return msg
case resp := <-ch:
if len(resp.Mapping) > 0 {
return resp.Mapping[0]
}
return nil
case <-time.After(time.Second):
return nil
}
}
// drainEmpty checks if a channel has no message within timeout.
// It reports true when nothing arrives on ch before the short deadline fires.
func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool {
	deadline := time.After(100 * time.Millisecond)
	select {
	case <-ch:
		return false
	case <-deadline:
		return true
	}
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
@@ -129,10 +153,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokens := make([]string, numProxies)
for i, ch := range channels {
resp := drainChannel(ch)
require.NotNil(t, resp, "proxy %d should receive a message", i)
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
msg := resp.Mapping[0]
msg := drainMapping(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, mapping.Domain, msg.Domain)
assert.Equal(t, mapping.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
@@ -181,16 +203,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
msg1 := drainMapping(ch1)
msg2 := drainMapping(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
// Delete operations should not generate tokens
assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, resp2.Mapping[0].AuthToken)
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
@@ -224,15 +244,10 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
s.SendServiceUpdate(update)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
msg1 := resp1.Mapping[0]
msg2 := resp2.Mapping[0]
msg1 := drainMapping(ch1)
msg2 := drainMapping(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
@@ -324,3 +339,314 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}
// TestSendServiceUpdateToCluster_FiltersOnCapability verifies that both TLS
// passthrough and HTTP mappings are delivered to every proxy in the cluster,
// whether or not the proxy reports custom-port support.
func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
// Proxy A supports custom ports.
chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
// Proxy B does NOT support custom ports (shared cloud proxy).
chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
ctx := context.Background()
// TLS passthrough works on all proxies regardless of custom port support.
tlsMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-tls",
AccountId: "account-1",
Domain: "db.example.com",
Mode: "tls",
ListenPort: 8443,
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
}
s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster)
msgA := drainMapping(chA)
msgB := drainMapping(chB)
assert.NotNil(t, msgA, "proxy-a should receive TLS mapping")
assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)")
// Send an HTTP mapping: both should receive it.
httpMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-http",
AccountId: "account-1",
Domain: "app.example.com",
Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:80"}},
}
s.SendServiceUpdateToCluster(ctx, httpMapping, cluster)
msgA = drainMapping(chA)
msgB = drainMapping(chB)
assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping")
assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping")
}
// TestSendServiceUpdateToCluster_TLSNotFiltered verifies that a TLS
// passthrough mapping reaches a proxy that reports no custom-port support.
func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
// No ListenPort is set: the mapping uses the default port.
tlsMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-tls",
AccountId: "account-1",
Domain: "db.example.com",
Mode: "tls",
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
}
s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster)
msg := drainMapping(chShared)
assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support")
}
// TestServiceModifyNotifications exercises every possible modification
// scenario for an existing service, verifying the correct update types
// reach the correct clusters.
func TestServiceModifyNotifications(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
// newServer builds a fresh server with one capable proxy per cluster and
// returns the per-cluster receive channels.
newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) {
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
chs := map[string]chan *proto.GetMappingUpdateResponse{
"cluster-a": registerFakeProxyWithCaps(s, "proxy-a", "cluster-a", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}),
"cluster-b": registerFakeProxyWithCaps(s, "proxy-b", "cluster-b", &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}),
}
return s, chs
}
// httpMapping builds a plain HTTP mapping for svc-1 with the given update type.
httpMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping {
return &proto.ProxyMapping{
Type: updateType,
Id: "svc-1",
AccountId: "acct-1",
Domain: "app.example.com",
Path: []*proto.PathMapping{{Path: "/", Target: "http://10.0.0.1:8080"}},
}
}
// tlsOnlyMapping builds a TLS passthrough mapping for svc-1 on port 8443.
tlsOnlyMapping := func(updateType proto.ProxyMappingUpdateType) *proto.ProxyMapping {
return &proto.ProxyMapping{
Type: updateType,
Id: "svc-1",
AccountId: "acct-1",
Domain: "app.example.com",
Mode: "tls",
ListenPort: 8443,
Path: []*proto.PathMapping{{Target: "10.0.0.1:443"}},
}
}
ctx := context.Background()
t.Run("targets changed sends MODIFIED to same cluster", func(t *testing.T) {
s, chs := newServer()
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg, "cluster-a should receive update")
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
assert.NotEmpty(t, msg.AuthToken, "MODIFIED should include token")
assert.True(t, drainEmpty(chs["cluster-b"]), "cluster-b should not receive update")
})
t.Run("auth config changed sends MODIFIED", func(t *testing.T) {
s, chs := newServer()
mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
mapping.Auth = &proto.Authentication{Password: true, Pin: true}
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
assert.True(t, msg.Auth.Password)
assert.True(t, msg.Auth.Pin)
})
t.Run("HTTP to TLS transition sends MODIFIED with TLS config", func(t *testing.T) {
s, chs := newServer()
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
assert.Equal(t, "tls", msg.Mode, "mode should be tls")
assert.Equal(t, int32(8443), msg.ListenPort)
assert.Len(t, msg.Path, 1, "should have one path entry with target address")
assert.Equal(t, "10.0.0.1:443", msg.Path[0].Target)
})
t.Run("TLS to HTTP transition sends MODIFIED without TLS", func(t *testing.T) {
s, chs := newServer()
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, msg.Type)
assert.Empty(t, msg.Mode, "mode should be empty for HTTP")
assert.True(t, len(msg.Path) > 0)
})
t.Run("TLS port changed sends MODIFIED with new port", func(t *testing.T) {
s, chs := newServer()
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
mapping.ListenPort = 9443
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
assert.Equal(t, int32(9443), msg.ListenPort)
})
t.Run("disable sends REMOVED to cluster", func(t *testing.T) {
s, chs := newServer()
// Manager sends Delete when service is disabled
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.Type)
assert.Empty(t, msg.AuthToken, "DELETE should not have token")
})
t.Run("enable sends CREATED to cluster", func(t *testing.T) {
s, chs := newServer()
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msg.Type)
assert.NotEmpty(t, msg.AuthToken)
})
t.Run("domain change with cluster change sends DELETE to old CREATE to new", func(t *testing.T) {
s, chs := newServer()
// This is the pattern the manager produces:
// 1. DELETE on old cluster
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a")
// 2. CREATE on new cluster
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-b")
msgA := drainMapping(chs["cluster-a"])
require.NotNil(t, msgA, "old cluster should receive DELETE")
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgA.Type)
msgB := drainMapping(chs["cluster-b"])
require.NotNil(t, msgB, "new cluster should receive CREATE")
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgB.Type)
assert.NotEmpty(t, msgB.AuthToken)
})
t.Run("domain change same cluster sends DELETE then CREATE", func(t *testing.T) {
s, chs := newServer()
// Domain changes within same cluster: manager sends DELETE (old domain) + CREATE (new domain).
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED), "cluster-a")
s.SendServiceUpdateToCluster(ctx, httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED), "cluster-a")
msgDel := drainMapping(chs["cluster-a"])
require.NotNil(t, msgDel, "same cluster should receive DELETE")
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msgDel.Type)
msgCreate := drainMapping(chs["cluster-a"])
require.NotNil(t, msgCreate, "same cluster should receive CREATE")
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, msgCreate.Type)
assert.NotEmpty(t, msgCreate.AuthToken)
})
t.Run("TLS passthrough sent to all proxies", func(t *testing.T) {
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
// TLS passthrough works on all proxies regardless of custom port support
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster)
msgModern := drainMapping(chModern)
require.NotNil(t, msgModern, "modern proxy receives TLS update")
assert.Equal(t, "tls", msgModern.Mode)
msgLegacy := drainMapping(chLegacy)
assert.NotNil(t, msgLegacy, "legacy proxy should also receive TLS passthrough")
})
t.Run("TLS on default port NOT filtered for legacy proxy", func(t *testing.T) {
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
mapping.ListenPort = 0 // default port
s.SendServiceUpdateToCluster(ctx, mapping, cluster)
msgLegacy := drainMapping(chLegacy)
assert.NotNil(t, msgLegacy, "legacy proxy should receive TLS on default port")
})
t.Run("passthrough and rewrite flags propagated", func(t *testing.T) {
s, chs := newServer()
mapping := httpMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
mapping.PassHostHeader = true
mapping.RewriteRedirects = true
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
assert.True(t, msg.PassHostHeader)
assert.True(t, msg.RewriteRedirects)
})
t.Run("multiple paths propagated in MODIFIED", func(t *testing.T) {
s, chs := newServer()
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED,
Id: "svc-multi",
AccountId: "acct-1",
Domain: "multi.example.com",
Path: []*proto.PathMapping{
{Path: "/", Target: "http://10.0.0.1:8080"},
{Path: "/api", Target: "http://10.0.0.2:9090"},
{Path: "/ws", Target: "http://10.0.0.3:3000"},
},
}
s.SendServiceUpdateToCluster(ctx, mapping, "cluster-a")
msg := drainMapping(chs["cluster-a"])
require.NotNil(t, msg)
require.Len(t, msg.Path, 3, "all paths should be present")
assert.Equal(t, "/", msg.Path[0].Path)
assert.Equal(t, "/api", msg.Path[1].Path)
assert.Equal(t, "/ws", msg.Path[2].Path)
})
}