feat(proxy): restore SyncMappings bidirectional stream with ack back-pressure

Reinstates the SyncMappings RPC that landed on origin/main and the
client-side fallback to GetMappingUpdate.

- proto: SyncMappings RPC + SyncMappingsRequest{Init|Ack} +
  SyncMappingsResponse messages.
- management proxy.go: SyncMappings server handler, recvSyncInit,
  sendSnapshotSync (per-batch send-then-wait-for-ack), drainRecv,
  waitForAck; proxyConnection.syncStream + sendResponse routes the
  same sendChan onto the bidi stream when set.
- proxy/server.go: trySyncMappings + handleSyncMappingsStream that
  acks after each batch is processed; outer loop tries SyncMappings
  first and falls back to GetMappingUpdate on Unimplemented.
  Capabilities lifted into proxyCapabilities() so both code paths
  use the same flags.
This commit is contained in:
mlsmaycon
2026-05-20 23:19:25 +02:00
parent 167ee08e14
commit 036e91cdea
7 changed files with 2129 additions and 390 deletions

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
@@ -78,10 +79,7 @@ type ProxyServiceServer struct {
// Manager for peers
peersManager peers.Manager
// Manager for users — also resolves user.AutoGroups to *types.Group
// records via GetUserWithGroups, used to stamp human-readable
// group names into JWT claims and onto upstream identity-bearing
// requests (llm_identity_inject's tags header).
// Manager for users
usersManager users.Manager
// Store for one-time authentication tokens
@@ -139,9 +137,12 @@ type proxyConnection struct {
tokenID string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
// syncStream is set when the proxy connected via SyncMappings.
// When non-nil, the sender goroutine uses this instead of stream.
syncStream proto.ProxyService_SyncMappingsServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
}
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
@@ -209,146 +210,322 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
s.proxyController = proxyController
}
// proxyConnectParams holds the validated parameters extracted from either
// a GetMappingUpdateRequest or a SyncMappingsInit message.
type proxyConnectParams struct {
proxyID string
address string
capabilities *proto.ProxyCapabilities
}
// GetMappingUpdate handles the control stream with proxy clients
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
ctx := stream.Context()
peerInfo := PeerIPFromContext(ctx)
log.Infof("New proxy connection from %s", peerInfo)
proxyID := req.GetProxyId()
if proxyID == "" {
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
}
proxyAddress := req.GetAddress()
if !isProxyAddressValid(proxyAddress) {
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
var accountID *string
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": sessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
}
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
Private: c.Private,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
if err != nil {
cancel()
if accountID != nil {
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
return err
}
params.capabilities = req.GetCapabilities()
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
stream: stream,
})
if err != nil {
return err
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
s.cleanupFailedSnapshot(stream.Context(), conn)
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
cancel()
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
}
// SyncMappings implements the bidirectional SyncMappings RPC.
// It mirrors GetMappingUpdate but provides application-level back-pressure:
// management waits for an ack from the proxy before sending the next batch.
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
init, err := recvSyncInit(stream)
if err != nil {
return err
}
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
if err != nil {
return err
}
params.capabilities = init.GetCapabilities()
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
syncStream: stream,
})
if err != nil {
return err
}
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
s.cleanupFailedSnapshot(stream.Context(), conn)
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.drainRecv(stream, errChan)
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
}
// recvSyncInit receives and validates the first message on a SyncMappings stream.
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
firstMsg, err := stream.Recv()
if err != nil {
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
}
init := firstMsg.GetInit()
if init == nil {
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
}
return init, nil
}
// validateProxyConnect validates the proxy ID and address, and checks cluster
// address availability for account-scoped tokens.
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
if proxyID == "" {
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
}
if !isProxyAddressValid(address) {
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
if err != nil {
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
}
}
return proxyConnectParams{proxyID: proxyID, address: address}, nil
}
// registerProxyConnection creates a proxyConnection, registers it with the
// proxy manager and cluster, and stores it in connectedProxies. The caller
// provides a partially initialised connSeed with stream-specific fields set;
// the remaining fields are filled in here.
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
peerInfo := PeerIPFromContext(ctx)
var accountID *string
var tokenID string
if token := GetProxyTokenFromContext(ctx); token != nil {
if token.AccountID != nil {
accountID = token.AccountID
}
tokenID = token.ID
}
sessionID := uuid.NewString()
s.supersedePriorConnection(params.proxyID, sessionID)
connCtx, cancel := context.WithCancel(ctx)
connSeed.proxyID = params.proxyID
connSeed.sessionID = sessionID
connSeed.address = params.address
connSeed.accountID = accountID
connSeed.tokenID = tokenID
connSeed.capabilities = params.capabilities
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
connSeed.ctx = connCtx
connSeed.cancel = cancel
var caps *proxy.Capabilities
if c := params.capabilities; c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
if err != nil {
cancel()
if accountID != nil {
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
s.connectedProxies.Store(params.proxyID, connSeed)
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
}
return connSeed, proxyRecord, nil
}
// supersedePriorConnection cancels any existing connection for the given proxy.
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": newSessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
}
// cleanupFailedSnapshot removes the connection from the cluster and store
// after a snapshot send failure.
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
}
}
conn.cancel()
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
}
}
// drainRecv consumes and discards messages from a bidirectional stream.
// The proxy sends an ack for every incremental update; we don't need them
// after the snapshot phase. Recv errors are forwarded to errChan.
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
for {
if _, err := stream.Recv(); err != nil {
errChan <- err
return
}
}
}
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
// and wait for termination. When bidi is true, normal stream closure (EOF,
// canceled) is treated as a clean disconnect rather than an error.
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
log.WithFields(log.Fields{
"proxy_id": conn.proxyID,
"session_id": conn.sessionID,
"address": conn.address,
"cluster_addr": conn.address,
"account_id": conn.accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
cancel()
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
go s.heartbeat(connCtx, conn, proxyRecord)
defer s.disconnectProxy(conn)
go s.heartbeat(conn.ctx, conn, proxyRecord)
select {
case err := <-errChan:
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done():
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
return connCtx.Err()
if bidi && isStreamClosed(err) {
log.Infof("Proxy %s stream closed", conn.proxyID)
return nil
}
log.Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
case <-conn.ctx.Done():
log.Infof("Proxy %s context canceled", conn.proxyID)
return conn.ctx.Err()
}
}
// disconnectProxy removes the connection from cluster and store, unless it
// has already been superseded by a newer connection.
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
conn.cancel()
return
}
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
}
conn.cancel()
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
}
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
// one batch, then waits for the proxy to ack before sending the next.
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
if s.snapshotBatchSize <= 0 {
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
}
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
return err
}
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
for _, m := range mappings[i:end] {
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
if err != nil {
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
}
m.AuthToken = token
}
if err := stream.Send(&proto.SyncMappingsResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
if err := waitForAck(stream); err != nil {
return err
}
}
if len(mappings) == 0 {
if err := stream.Send(&proto.SyncMappingsResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
if err := waitForAck(stream); err != nil {
return err
}
}
return nil
}
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
msg, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive ack: %w", err)
}
if msg.GetAck() == nil {
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
}
return nil
}
// heartbeat updates the proxy's last_seen timestamp every minute and
// disconnects the proxy if its access token has been revoked.
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
@@ -385,6 +562,9 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
if s.snapshotBatchSize <= 0 {
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
}
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
@@ -464,12 +644,26 @@ func isProxyAddressValid(addr string) bool {
return err == nil
}
// sender handles sending messages to proxy
// isStreamClosed returns true for errors that indicate normal stream
// termination: io.EOF, context cancellation, or gRPC Canceled.
func isStreamClosed(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return true
}
return status.Code(err) == codes.Canceled
}
// sender handles sending messages to proxy.
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
// otherwise the legacy GetMappingUpdateResponse stream is used.
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
for {
select {
case resp := <-conn.sendChan:
if err := conn.stream.Send(resp); err != nil {
if err := conn.sendResponse(resp); err != nil {
errChan <- err
return
}
@@ -479,6 +673,17 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
}
}
// sendResponse sends a mapping update on whichever stream the proxy connected with.
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
if conn.syncStream != nil {
return conn.syncStream.Send(&proto.SyncMappingsResponse{
Mapping: resp.Mapping,
InitialSyncComplete: resp.InitialSyncComplete,
})
}
return conn.stream.Send(resp)
}
// SendAccessLog processes access log from proxy
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
@@ -551,7 +756,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
}
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
connUpdate = filterMappingsForProxy(conn, connUpdate)
if len(connUpdate.Mapping) == 0 {
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
return true
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
@@ -663,7 +868,7 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
continue
}
if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: missing capability for mapping %s", proxyID, update.Id)
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue
}
msg := s.perProxyMessage(updateResponse, proxyID)
@@ -682,7 +887,10 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
}
// proxyAcceptsMapping returns whether the proxy can receive this mapping. Private requires SupportsPrivateService; custom-port L4 requires SupportsCustomPorts. Removes always pass.
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
// Private mappings require SupportsPrivateService; custom-port L4 mappings
// require SupportsCustomPorts. Remove operations always pass so proxies can
// clean up.
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
return true
@@ -696,23 +904,30 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
return true
}
// Old proxies that never reported capabilities don't understand
// custom port mappings.
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
}
// filterMappingsForProxy drops mappings the proxy can't safely receive.
// filterMappingsForProxy drops mappings the proxy cannot safely receive
// (e.g. private mappings to a proxy without SupportsPrivateService).
// Returns the input unchanged when no filtering is needed.
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
if update == nil {
if update == nil || len(update.Mapping) == 0 {
return update
}
filtered := make([]*proto.ProxyMapping, 0, len(update.Mapping))
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, m := range update.Mapping {
if !proxyAcceptsMapping(conn, m) {
continue
}
filtered = append(filtered, m)
kept = append(kept, m)
}
if len(kept) == len(update.Mapping) {
return update
}
return &proto.GetMappingUpdateResponse{
Mapping: filtered,
Mapping: kept,
InitialSyncComplete: update.InitialSyncComplete,
}
}
@@ -762,7 +977,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
Mode: m.Mode,
ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
Private: m.Private,
}
}
@@ -771,7 +985,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
return nil, err
}
service, err := s.lookupServiceByID(ctx, req.GetAccountId(), req.GetId())
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
@@ -895,12 +1109,9 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
}
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
// into parallel id and name slices. Used to feed JWT claims and
// response messages from a single GetUserWithGroups call. Position i
// of the returned slices always pairs the same group: ids[i] is the
// group id, names[i] is its display name. Records the manager skipped
// (orphan ids in user.AutoGroups with no matching group row) are
// already absent here, so the consumer can rely on positional pairing.
// into parallel id and name slices. ids[i] and names[i] always pair to
// the same group. nil entries (orphan ids the manager couldn't resolve)
// are skipped so the consumer can rely on positional pairing.
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
if len(groups) == 0 {
return nil, nil
@@ -1063,7 +1274,19 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme")
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
if _, err := s.getAccountServiceByDomain(ctx, req.GetAccountId(), redirectURL.Hostname()); err != nil {
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
}
var found bool
for _, service := range services {
if service.Domain == redirectURL.Hostname() {
found = true
break
}
}
if !found {
log.WithContext(ctx).Debugf("OIDC redirect URL %q does not match any service domain", redirectURL.Hostname())
return nil, status.Errorf(codes.FailedPrecondition, "service not found in store")
}
@@ -1266,11 +1489,6 @@ func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, acco
return nil, fmt.Errorf("service not found for domain %s in account %s", domain, accountID)
}
// lookupServiceByID resolves a service by ID via the persisted store.
func (s *ProxyServiceServer) lookupServiceByID(ctx context.Context, accountID, serviceID string) (*rpservice.Service, error) {
return s.serviceManager.GetServiceByID(ctx, accountID, serviceID)
}
// ValidateSession validates a session token and checks if the user has access to the domain.
func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.ValidateSessionRequest) (*proto.ValidateSessionResponse, error) {
domain := req.GetDomain()
@@ -1392,138 +1610,6 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
return s.serviceManager.GetServiceByDomain(ctx, domain)
}
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
// checks the peer's group membership against the service's access groups.
// Peers without a user (machine agents, automation workloads) are
// first-class callers, so authorisation runs off peer-group memberships
// rather than the optional owning user's auto-groups.
//
// The minted session JWT carries peer.ID (Subject) and a display identity
// (Email): user.Email when the peer is attached to a user, else peer.Name.
// Downstream identity-stamping middlewares emit this display identity on
// upstream requests, so per-human attribution lines up across a user's
// devices and unattached peers still produce a stable, human-readable id.
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
domain := req.GetDomain()
tunnelIPStr := req.GetTunnelIp()
if domain == "" || tunnelIPStr == "" {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "missing domain or tunnel_ip",
}, nil
}
tunnelIP := net.ParseIP(tunnelIPStr)
if tunnelIP == nil {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "invalid_tunnel_ip",
}, nil
}
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Debug("ValidateTunnelPeer: service not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "service_not_found",
}, nil
}
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
if err != nil || peer == nil {
log.WithFields(log.Fields{
"domain": domain,
"tunnel_ip": tunnelIPStr,
}).Debug("ValidateTunnelPeer: peer not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"peer_id": peer.ID,
"error": err.Error(),
}).Debug("ValidateTunnelPeer: peer groups lookup failed")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the
// human is the principal — multiple peers belonging to the same
// user share one access-log identity AND one consumption-counter
// dimension (caps are per-human, not per-device). Unlinked peers
// (machine agents, automation workloads) are their own principals
// keyed on peer.ID. displayIdentity is what upstream gateways
// (LiteLLM, Portkey) tag spend with — user.Email when linked, the
// peer's hostname when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{
"domain": domain,
"peer_id": peer.ID,
"error": err.Error(),
}).Debug("ValidateTunnelPeer: access denied")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
UserId: principalID,
UserEmail: displayIdentity,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
if err != nil {
return nil, err
}
log.WithFields(log.Fields{
"domain": domain,
"tunnel_ip": tunnelIPStr,
"peer_id": peer.ID,
"peer_name": peer.Name,
"principal_id": principalID,
"identity": displayIdentity,
}).Debug("ValidateTunnelPeer: access granted")
return &proto.ValidateTunnelPeerResponse{
Valid: true,
UserId: principalID,
UserEmail: displayIdentity,
SessionToken: token,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
@@ -1548,7 +1634,122 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
return fmt.Errorf("user not in allowed groups")
}
// checkPeerGroupAccess gates ValidateTunnelPeer; private with empty AccessGroups fails closed.
func ptr[T any](v T) *T { return &v }
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
// checks the peer's group membership against the service's access groups.
// Peers without a user (machine agents, automation workloads) are first-class
// callers; authorisation runs off peer-group memberships rather than the
// optional owning user's auto-groups. On success a session JWT is minted so
// the proxy can install a cookie and skip subsequent management round-trips.
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
domain := req.GetDomain()
tunnelIPStr := req.GetTunnelIp()
if domain == "" || tunnelIPStr == "" {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "missing domain or tunnel_ip",
}, nil
}
tunnelIP := net.ParseIP(tunnelIPStr)
if tunnelIP == nil {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "invalid_tunnel_ip",
}, nil
}
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "service_not_found",
}, nil
}
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
if err != nil || peer == nil {
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
UserId: principalID,
UserEmail: displayIdentity,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
if err != nil {
return nil, err
}
log.WithFields(log.Fields{
"domain": domain,
"tunnel_ip": tunnelIPStr,
"peer_id": peer.ID,
"principal_id": principalID,
}).Debug("ValidateTunnelPeer: access granted")
return &proto.ValidateTunnelPeerResponse{
Valid: true,
UserId: principalID,
UserEmail: displayIdentity,
SessionToken: token,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security
// boundary and must not trust upstream state). Bearer-auth services authorise
// against DistributionGroups when populated. Non-private non-bearer services
// are open.
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
if service.Private {
if len(service.AccessGroups) == 0 {
@@ -1556,28 +1757,26 @@ func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) err
}
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
if len(service.Auth.BearerAuth.DistributionGroups) == 0 {
return nil
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
}
return nil
}
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
// else a non-nil error.
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
allowedSet := make(map[string]bool, len(allowedGroups))
for _, groupID := range allowedGroups {
allowedSet[groupID] = true
if len(allowedGroups) == 0 {
return fmt.Errorf("no allowed groups configured")
}
for _, groupID := range peerGroupIDs {
if allowedSet[groupID] {
allowed := make(map[string]struct{}, len(allowedGroups))
for _, g := range allowedGroups {
allowed[g] = struct{}{}
}
for _, g := range peerGroupIDs {
if _, ok := allowed[g]; ok {
return nil
}
}
return fmt.Errorf("peer not in allowed groups")
}
func ptr[T any](v T) *T { return &v }