diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 7680fb2eb..25bf52d8c 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "io" "net" "net/http" "net/url" @@ -78,10 +79,7 @@ type ProxyServiceServer struct { // Manager for peers peersManager peers.Manager - // Manager for users — also resolves user.AutoGroups to *types.Group - // records via GetUserWithGroups, used to stamp human-readable - // group names into JWT claims and onto upstream identity-bearing - // requests (llm_identity_inject's tags header). + // Manager for users usersManager users.Manager // Store for one-time authentication tokens @@ -139,9 +137,12 @@ type proxyConnection struct { tokenID string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.GetMappingUpdateResponse - ctx context.Context - cancel context.CancelFunc + // syncStream is set when the proxy connected via SyncMappings. + // When non-nil, the sender goroutine uses this instead of stream. + syncStream proto.ProxyService_SyncMappingsServer + sendChan chan *proto.GetMappingUpdateResponse + ctx context.Context + cancel context.CancelFunc } func enforceAccountScope(ctx context.Context, requestAccountID string) error { @@ -209,146 +210,322 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller s.proxyController = proxyController } +// proxyConnectParams holds the validated parameters extracted from either +// a GetMappingUpdateRequest or a SyncMappingsInit message. +type proxyConnectParams struct { + proxyID string + address string + capabilities *proto.ProxyCapabilities +} + // GetMappingUpdate handles the control stream with proxy clients func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { - ctx := stream.Context() - - peerInfo := PeerIPFromContext(ctx) - log.Infof("New proxy connection from %s", peerInfo) - - proxyID := req.GetProxyId() - if proxyID == "" { - return status.Errorf(codes.InvalidArgument, "proxy_id is required") - } - - proxyAddress := req.GetAddress() - if !isProxyAddressValid(proxyAddress) { - return status.Errorf(codes.InvalidArgument, "proxy address is invalid") - } - - var accountID *string - token := GetProxyTokenFromContext(ctx) - if token != nil && token.AccountID != nil { - accountID = token.AccountID - - available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID) - if err != nil { - return status.Errorf(codes.Internal, "check cluster address: %v", err) - } - if !available { - return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress) - } - } - - var tokenID string - if token != nil { - tokenID = token.ID - } - - sessionID := uuid.NewString() - - if old, loaded := s.connectedProxies.Load(proxyID); loaded { - oldConn := old.(*proxyConnection) - log.WithFields(log.Fields{ - "proxy_id": proxyID, - "old_session_id": oldConn.sessionID, - "new_session_id": sessionID, - }).Info("Superseding existing proxy connection") - oldConn.cancel() - } - - connCtx, cancel := context.WithCancel(ctx) - conn := &proxyConnection{ - proxyID: proxyID, - sessionID: sessionID, - address: proxyAddress, - accountID: accountID, - tokenID: tokenID, - capabilities: req.GetCapabilities(), - stream: stream, - sendChan: make(chan *proto.GetMappingUpdateResponse, 100), - ctx: connCtx, - cancel: cancel, - } - - var caps *proxy.Capabilities - if c := req.GetCapabilities(); c != nil { - caps = &proxy.Capabilities{ - SupportsCustomPorts: c.SupportsCustomPorts, - RequireSubdomain: c.RequireSubdomain, - SupportsCrowdsec: c.SupportsCrowdsec, - Private: c.Private, - } - } - proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps) + params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context()) if err != nil { - cancel() - if accountID != nil { - return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) - } - log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - return status.Errorf(codes.Internal, "register proxy in database: %v", err) + return err + } + params.capabilities = req.GetCapabilities() + + conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{ + stream: stream, + }) + if err != nil { + return err } - s.connectedProxies.Store(proxyID, conn) - if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) - } - - if err := s.sendSnapshot(ctx, conn); err != nil { - if s.connectedProxies.CompareAndDelete(proxyID, conn) { - if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil { - log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr) - } - } - cancel() - if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil { - log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr) - } - return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) + if err := s.sendSnapshot(stream.Context(), conn); err != nil { + s.cleanupFailedSnapshot(stream.Context(), conn) + return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err) } errChan := make(chan error, 2) go s.sender(conn, errChan) - log.WithFields(log.Fields{ - "proxy_id": proxyID, - "session_id": sessionID, - "address": proxyAddress, - "cluster_addr": proxyAddress, - "account_id": accountID, - "total_proxies": len(s.GetConnectedProxies()), - }).Info("Proxy registered in cluster") - defer func() { - if !s.connectedProxies.CompareAndDelete(proxyID, conn) { - log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID) - cancel() + return s.serveProxyConnection(conn, proxyRecord, errChan, false) +} + +// SyncMappings implements the bidirectional SyncMappings RPC. +// It mirrors GetMappingUpdate but provides application-level back-pressure: +// management waits for an ack from the proxy before sending the next batch. +func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error { + init, err := recvSyncInit(stream) + if err != nil { + return err + } + + params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context()) + if err != nil { + return err + } + params.capabilities = init.GetCapabilities() + + conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{ + syncStream: stream, + }) + if err != nil { + return err + } + + if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil { + s.cleanupFailedSnapshot(stream.Context(), conn) + return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err) + } + + errChan := make(chan error, 2) + go s.sender(conn, errChan) + go s.drainRecv(stream, errChan) + + return s.serveProxyConnection(conn, proxyRecord, errChan, true) +} + +// recvSyncInit receives and validates the first message on a SyncMappings stream. +func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) { + firstMsg, err := stream.Recv() + if err != nil { + return nil, status.Errorf(codes.Internal, "receive init: %v", err) + } + init := firstMsg.GetInit() + if init == nil { + return nil, status.Errorf(codes.InvalidArgument, "first message must be init") + } + return init, nil +} + +// validateProxyConnect validates the proxy ID and address, and checks cluster +// address availability for account-scoped tokens. +func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) { + if proxyID == "" { + return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required") + } + if !isProxyAddressValid(address) { + return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid") + } + + token := GetProxyTokenFromContext(ctx) + if token != nil && token.AccountID != nil { + available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID) + if err != nil { + return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err) + } + if !available { + return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address) + } + } + + return proxyConnectParams{proxyID: proxyID, address: address}, nil +} + +// registerProxyConnection creates a proxyConnection, registers it with the +// proxy manager and cluster, and stores it in connectedProxies. The caller +// provides a partially initialised connSeed with stream-specific fields set; +// the remaining fields are filled in here. +func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) { + peerInfo := PeerIPFromContext(ctx) + + var accountID *string + var tokenID string + if token := GetProxyTokenFromContext(ctx); token != nil { + if token.AccountID != nil { + accountID = token.AccountID + } + tokenID = token.ID + } + + sessionID := uuid.NewString() + s.supersedePriorConnection(params.proxyID, sessionID) + + connCtx, cancel := context.WithCancel(ctx) + connSeed.proxyID = params.proxyID + connSeed.sessionID = sessionID + connSeed.address = params.address + connSeed.accountID = accountID + connSeed.tokenID = tokenID + connSeed.capabilities = params.capabilities + connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100) + connSeed.ctx = connCtx + connSeed.cancel = cancel + + var caps *proxy.Capabilities + if c := params.capabilities; c != nil { + caps = &proxy.Capabilities{ + SupportsCustomPorts: c.SupportsCustomPorts, + RequireSubdomain: c.RequireSubdomain, + SupportsCrowdsec: c.SupportsCrowdsec, + } + } + + proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps) + if err != nil { + cancel() + if accountID != nil { + return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) + } + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err) + return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err) + } + + s.connectedProxies.Store(params.proxyID, connSeed) + if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err) + } + + return connSeed, proxyRecord, nil +} + +// supersedePriorConnection cancels any existing connection for the given proxy. +func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) { + if old, loaded := s.connectedProxies.Load(proxyID); loaded { + oldConn := old.(*proxyConnection) + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "old_session_id": oldConn.sessionID, + "new_session_id": newSessionID, + }).Info("Superseding existing proxy connection") + oldConn.cancel() + } +} + +// cleanupFailedSnapshot removes the connection from the cluster and store +// after a snapshot send failure. +func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) { + if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) { + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err) + } + } + conn.cancel() + if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err) + } +} + +// drainRecv consumes and discards messages from a bidirectional stream. +// The proxy sends an ack for every incremental update; we don't need them +// after the snapshot phase. Recv errors are forwarded to errChan. +func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) { + for { + if _, err := stream.Recv(); err != nil { + errChan <- err return } + } +} - if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { - log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) - } - if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil { - log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) - } +// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender, +// and wait for termination. When bidi is true, normal stream closure (EOF, +// canceled) is treated as a clean disconnect rather than an error. +func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error { + log.WithFields(log.Fields{ + "proxy_id": conn.proxyID, + "session_id": conn.sessionID, + "address": conn.address, + "cluster_addr": conn.address, + "account_id": conn.accountID, + "total_proxies": len(s.GetConnectedProxies()), + }).Info("Proxy registered in cluster") - cancel() - log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) - }() - - go s.heartbeat(connCtx, conn, proxyRecord) + defer s.disconnectProxy(conn) + go s.heartbeat(conn.ctx, conn, proxyRecord) select { case err := <-errChan: - log.WithContext(ctx).Warnf("Failed to send update: %v", err) - return fmt.Errorf("send update to proxy %s: %w", proxyID, err) - case <-connCtx.Done(): - log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID) - return connCtx.Err() + if bidi && isStreamClosed(err) { + log.Infof("Proxy %s stream closed", conn.proxyID) + return nil + } + log.Warnf("Failed to send update: %v", err) + return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err) + case <-conn.ctx.Done(): + log.Infof("Proxy %s context canceled", conn.proxyID) + return conn.ctx.Err() } } +// disconnectProxy removes the connection from cluster and store, unless it +// has already been superseded by a newer connection. +func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) { + if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) { + log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID) + conn.cancel() + return + } + + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil { + log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err) + } + if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil { + log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err) + } + + conn.cancel() + log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID) +} + +// sendSnapshotSync sends the initial snapshot with back-pressure: it sends +// one batch, then waits for the proxy to ack before sending the next. +func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error { + if !isProxyAddressValid(conn.address) { + return fmt.Errorf("proxy address is invalid") + } + if s.snapshotBatchSize <= 0 { + return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize) + } + + mappings, err := s.snapshotServiceMappings(ctx, conn) + if err != nil { + return err + } + + for i := 0; i < len(mappings); i += s.snapshotBatchSize { + end := i + s.snapshotBatchSize + if end > len(mappings) { + end = len(mappings) + } + for _, m := range mappings[i:end] { + token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL()) + if err != nil { + return fmt.Errorf("generate auth token for service %s: %w", m.Id, err) + } + m.AuthToken = token + } + if err := stream.Send(&proto.SyncMappingsResponse{ + Mapping: mappings[i:end], + InitialSyncComplete: end == len(mappings), + }); err != nil { + return fmt.Errorf("send snapshot batch: %w", err) + } + + if err := waitForAck(stream); err != nil { + return err + } + } + + if len(mappings) == 0 { + if err := stream.Send(&proto.SyncMappingsResponse{ + InitialSyncComplete: true, + }); err != nil { + return fmt.Errorf("send snapshot completion: %w", err) + } + + if err := waitForAck(stream); err != nil { + return err + } + } + + return nil +} + +func waitForAck(stream proto.ProxyService_SyncMappingsServer) error { + msg, err := stream.Recv() + if err != nil { + return fmt.Errorf("receive ack: %w", err) + } + if msg.GetAck() == nil { + return fmt.Errorf("expected ack, got %T", msg.GetMsg()) + } + return nil +} + // heartbeat updates the proxy's last_seen timestamp every minute and // disconnects the proxy if its access token has been revoked. func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) { @@ -385,6 +562,9 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") } + if s.snapshotBatchSize <= 0 { + return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize) + } mappings, err := s.snapshotServiceMappings(ctx, conn) if err != nil { @@ -464,12 +644,26 @@ func isProxyAddressValid(addr string) bool { return err == nil } -// sender handles sending messages to proxy +// isStreamClosed returns true for errors that indicate normal stream +// termination: io.EOF, context cancellation, or gRPC Canceled. +func isStreamClosed(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + return true + } + return status.Code(err) == codes.Canceled +} + +// sender handles sending messages to proxy. +// When conn.syncStream is set the message is sent as SyncMappingsResponse; +// otherwise the legacy GetMappingUpdateResponse stream is used. func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { select { case resp := <-conn.sendChan: - if err := conn.stream.Send(resp); err != nil { + if err := conn.sendResponse(resp); err != nil { errChan <- err return } @@ -479,6 +673,17 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) } } +// sendResponse sends a mapping update on whichever stream the proxy connected with. +func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error { + if conn.syncStream != nil { + return conn.syncStream.Send(&proto.SyncMappingsResponse{ + Mapping: resp.Mapping, + InitialSyncComplete: resp.InitialSyncComplete, + }) + } + return conn.stream.Send(resp) +} + // SendAccessLog processes access log from proxy func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { accessLog := req.GetLog() @@ -551,7 +756,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes } // Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService). connUpdate = filterMappingsForProxy(conn, connUpdate) - if len(connUpdate.Mapping) == 0 { + if connUpdate == nil || len(connUpdate.Mapping) == 0 { return true } resp := s.perProxyMessage(connUpdate, conn.proxyID) @@ -663,7 +868,7 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd continue } if !proxyAcceptsMapping(conn, update) { - log.WithContext(ctx).Debugf("Skipping proxy %s: missing capability for mapping %s", proxyID, update.Id) + log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) continue } msg := s.perProxyMessage(updateResponse, proxyID) @@ -682,7 +887,10 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd } } -// proxyAcceptsMapping returns whether the proxy can receive this mapping. Private requires SupportsPrivateService; custom-port L4 requires SupportsCustomPorts. Removes always pass. +// proxyAcceptsMapping returns whether the proxy can receive this mapping. +// Private mappings require SupportsPrivateService; custom-port L4 mappings +// require SupportsCustomPorts. Remove operations always pass so proxies can +// clean up. func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool { if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { return true @@ -696,23 +904,30 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo if mapping.ListenPort == 0 || mapping.Mode == "tls" { return true } + // Old proxies that never reported capabilities don't understand + // custom port mappings. return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil } -// filterMappingsForProxy drops mappings the proxy can't safely receive. +// filterMappingsForProxy drops mappings the proxy cannot safely receive +// (e.g. private mappings to a proxy without SupportsPrivateService). +// Returns the input unchanged when no filtering is needed. func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse { - if update == nil { + if update == nil || len(update.Mapping) == 0 { return update } - filtered := make([]*proto.ProxyMapping, 0, len(update.Mapping)) + kept := make([]*proto.ProxyMapping, 0, len(update.Mapping)) for _, m := range update.Mapping { if !proxyAcceptsMapping(conn, m) { continue } - filtered = append(filtered, m) + kept = append(kept, m) + } + if len(kept) == len(update.Mapping) { + return update } return &proto.GetMappingUpdateResponse{ - Mapping: filtered, + Mapping: kept, InitialSyncComplete: update.InitialSyncComplete, } } @@ -762,7 +977,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { Mode: m.Mode, ListenPort: m.ListenPort, AccessRestrictions: m.AccessRestrictions, - Private: m.Private, } } @@ -771,7 +985,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen return nil, err } - service, err := s.lookupServiceByID(ctx, req.GetAccountId(), req.GetId()) + service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { log.WithContext(ctx).Debugf("failed to get service from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err) @@ -895,12 +1109,9 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic } // pairGroupIDsAndNames splits a slice of resolved *types.Group records -// into parallel id and name slices. Used to feed JWT claims and -// response messages from a single GetUserWithGroups call. Position i -// of the returned slices always pairs the same group: ids[i] is the -// group id, names[i] is its display name. Records the manager skipped -// (orphan ids in user.AutoGroups with no matching group row) are -// already absent here, so the consumer can rely on positional pairing. +// into parallel id and name slices. ids[i] and names[i] always pair to +// the same group. nil entries (orphan ids the manager couldn't resolve) +// are skipped so the consumer can rely on positional pairing. func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) { if len(groups) == 0 { return nil, nil @@ -1063,7 +1274,19 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme") } // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. - if _, err := s.getAccountServiceByDomain(ctx, req.GetAccountId(), redirectURL.Hostname()); err != nil { + services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId()) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account services: %v", err) + return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err) + } + var found bool + for _, service := range services { + if service.Domain == redirectURL.Hostname() { + found = true + break + } + } + if !found { log.WithContext(ctx).Debugf("OIDC redirect URL %q does not match any service domain", redirectURL.Hostname()) return nil, status.Errorf(codes.FailedPrecondition, "service not found in store") } @@ -1266,11 +1489,6 @@ func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, acco return nil, fmt.Errorf("service not found for domain %s in account %s", domain, accountID) } -// lookupServiceByID resolves a service by ID via the persisted store. -func (s *ProxyServiceServer) lookupServiceByID(ctx context.Context, accountID, serviceID string) (*rpservice.Service, error) { - return s.serviceManager.GetServiceByID(ctx, accountID, serviceID) -} - // ValidateSession validates a session token and checks if the user has access to the domain. func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.ValidateSessionRequest) (*proto.ValidateSessionResponse, error) { domain := req.GetDomain() @@ -1392,138 +1610,6 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri return s.serviceManager.GetServiceByDomain(ctx, domain) } -// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and -// checks the peer's group membership against the service's access groups. -// Peers without a user (machine agents, automation workloads) are -// first-class callers, so authorisation runs off peer-group memberships -// rather than the optional owning user's auto-groups. -// -// The minted session JWT carries peer.ID (Subject) and a display identity -// (Email): user.Email when the peer is attached to a user, else peer.Name. -// Downstream identity-stamping middlewares emit this display identity on -// upstream requests, so per-human attribution lines up across a user's -// devices and unattached peers still produce a stable, human-readable id. -func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { - domain := req.GetDomain() - tunnelIPStr := req.GetTunnelIp() - - if domain == "" || tunnelIPStr == "" { - return &proto.ValidateTunnelPeerResponse{ - Valid: false, - DeniedReason: "missing domain or tunnel_ip", - }, nil - } - - tunnelIP := net.ParseIP(tunnelIPStr) - if tunnelIP == nil { - return &proto.ValidateTunnelPeerResponse{ - Valid: false, - DeniedReason: "invalid_tunnel_ip", - }, nil - } - - service, err := s.getServiceByDomain(ctx, domain) - if err != nil { - log.WithFields(log.Fields{ - "domain": domain, - "error": err.Error(), - }).Debug("ValidateTunnelPeer: service not found") - //nolint:nilerr - return &proto.ValidateTunnelPeerResponse{ - Valid: false, - DeniedReason: "service_not_found", - }, nil - } - - peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP) - if err != nil || peer == nil { - log.WithFields(log.Fields{ - "domain": domain, - "tunnel_ip": tunnelIPStr, - }).Debug("ValidateTunnelPeer: peer not found") - //nolint:nilerr - return &proto.ValidateTunnelPeerResponse{ - Valid: false, - DeniedReason: "peer_not_found", - }, nil - } - - _, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID) - if err != nil { - log.WithFields(log.Fields{ - "domain": domain, - "peer_id": peer.ID, - "error": err.Error(), - }).Debug("ValidateTunnelPeer: peer groups lookup failed") - //nolint:nilerr - return &proto.ValidateTunnelPeerResponse{ - Valid: false, - DeniedReason: "peer_not_found", - }, nil - } - - groupIDs, groupNames := pairGroupIDsAndNames(peerGroups) - - // Resolve the principal: when the peer is linked to a user, the - // human is the principal — multiple peers belonging to the same - // user share one access-log identity AND one consumption-counter - // dimension (caps are per-human, not per-device). Unlinked peers - // (machine agents, automation workloads) are their own principals - // keyed on peer.ID. displayIdentity is what upstream gateways - // (LiteLLM, Portkey) tag spend with — user.Email when linked, the - // peer's hostname when not. - principalID := peer.ID - displayIdentity := peer.Name - if peer.UserID != "" { - if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil { - principalID = user.Id - if user.Email != "" { - displayIdentity = user.Email - } - } - } - - if err := checkPeerGroupAccess(service, groupIDs); err != nil { - log.WithFields(log.Fields{ - "domain": domain, - "peer_id": peer.ID, - "error": err.Error(), - }).Debug("ValidateTunnelPeer: access denied") - //nolint:nilerr - return &proto.ValidateTunnelPeerResponse{ - Valid: false, - UserId: principalID, - UserEmail: displayIdentity, - DeniedReason: "not_in_group", - PeerGroupIds: groupIDs, - PeerGroupNames: groupNames, - }, nil - } - - token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames) - if err != nil { - return nil, err - } - - log.WithFields(log.Fields{ - "domain": domain, - "tunnel_ip": tunnelIPStr, - "peer_id": peer.ID, - "peer_name": peer.Name, - "principal_id": principalID, - "identity": displayIdentity, - }).Debug("ValidateTunnelPeer: access granted") - - return &proto.ValidateTunnelPeerResponse{ - Valid: true, - UserId: principalID, - UserEmail: displayIdentity, - SessionToken: token, - PeerGroupIds: groupIDs, - PeerGroupNames: groupNames, - }, nil -} - func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled { return nil @@ -1548,7 +1634,122 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user * return fmt.Errorf("user not in allowed groups") } -// checkPeerGroupAccess gates ValidateTunnelPeer; private with empty AccessGroups fails closed. +func ptr[T any](v T) *T { return &v } + +// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and +// checks the peer's group membership against the service's access groups. +// Peers without a user (machine agents, automation workloads) are first-class +// callers; authorisation runs off peer-group memberships rather than the +// optional owning user's auto-groups. On success a session JWT is minted so +// the proxy can install a cookie and skip subsequent management round-trips. +func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + domain := req.GetDomain() + tunnelIPStr := req.GetTunnelIp() + + if domain == "" || tunnelIPStr == "" { + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "missing domain or tunnel_ip", + }, nil + } + + tunnelIP := net.ParseIP(tunnelIPStr) + if tunnelIP == nil { + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "invalid_tunnel_ip", + }, nil + } + + service, err := s.getServiceByDomain(ctx, domain) + if err != nil { + log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "service_not_found", + }, nil + } + + peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP) + if err != nil || peer == nil { + log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "peer_not_found", + }, nil + } + + _, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID) + if err != nil { + log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "peer_not_found", + }, nil + } + + groupIDs, groupNames := pairGroupIDsAndNames(peerGroups) + + // Resolve the principal: when the peer is linked to a user, the human + // is the principal so multiple peers owned by the same user share a + // single identity. Unlinked peers (machine agents) are their own + // principal keyed on peer.ID. displayIdentity is what upstream gateways + // tag spend with — user.Email when linked, peer.Name when not. + principalID := peer.ID + displayIdentity := peer.Name + if peer.UserID != "" { + if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil { + principalID = user.Id + if user.Email != "" { + displayIdentity = user.Email + } + } + } + + if err := checkPeerGroupAccess(service, groupIDs); err != nil { + log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + UserId: principalID, + UserEmail: displayIdentity, + DeniedReason: "not_in_group", + PeerGroupIds: groupIDs, + PeerGroupNames: groupNames, + }, nil + } + + token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames) + if err != nil { + return nil, err + } + + log.WithFields(log.Fields{ + "domain": domain, + "tunnel_ip": tunnelIPStr, + "peer_id": peer.ID, + "principal_id": principalID, + }).Debug("ValidateTunnelPeer: access granted") + + return &proto.ValidateTunnelPeerResponse{ + Valid: true, + UserId: principalID, + UserEmail: displayIdentity, + SessionToken: token, + PeerGroupIds: groupIDs, + PeerGroupNames: groupNames, + }, nil +} + +// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required +// groups. Private services authorise against AccessGroups (empty list fails +// closed — Validate() rejects that at save time but the RPC is the security +// boundary and must not trust upstream state). Bearer-auth services authorise +// against DistributionGroups when populated. Non-private non-bearer services +// are open. func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error { if service.Private { if len(service.AccessGroups) == 0 { @@ -1556,28 +1757,26 @@ func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) err } return matchAnyGroup(service.AccessGroups, peerGroupIDs) } - - if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled { - if len(service.Auth.BearerAuth.DistributionGroups) == 0 { - return nil - } + if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 { return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs) } - return nil } +// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups, +// else a non-nil error. func matchAnyGroup(allowedGroups, peerGroupIDs []string) error { - allowedSet := make(map[string]bool, len(allowedGroups)) - for _, groupID := range allowedGroups { - allowedSet[groupID] = true + if len(allowedGroups) == 0 { + return fmt.Errorf("no allowed groups configured") } - for _, groupID := range peerGroupIDs { - if allowedSet[groupID] { + allowed := make(map[string]struct{}, len(allowedGroups)) + for _, g := range allowedGroups { + allowed[g] = struct{}{} + } + for _, g := range peerGroupIDs { + if _, ok := allowed[g]; ok { return nil } } return fmt.Errorf("peer not in allowed groups") } - -func ptr[T any](v T) *T { return &v } diff --git a/management/internals/shared/grpc/sync_mappings_test.go b/management/internals/shared/grpc/sync_mappings_test.go new file mode 100644 index 000000000..97f6183bb --- /dev/null +++ b/management/internals/shared/grpc/sync_mappings_test.go @@ -0,0 +1,411 @@ +package grpc + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records +// sent messages and returns pre-loaded ack responses from Recv. +type syncRecordingStream struct { + grpc.ServerStream + + mu sync.Mutex + sent []*proto.SyncMappingsResponse + recvMsgs []*proto.SyncMappingsRequest + recvIdx int +} + +func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sent = append(s.sent, m) + return nil +} + +func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.recvIdx >= len(s.recvMsgs) { + return nil, fmt.Errorf("no more recv messages") + } + msg := s.recvMsgs[s.recvIdx] + s.recvIdx++ + return msg, nil +} + +func (s *syncRecordingStream) Context() context.Context { return context.Background() } +func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil } +func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil } +func (s *syncRecordingStream) SetTrailer(metadata.MD) {} +func (s *syncRecordingStream) SendMsg(any) error { return nil } +func (s *syncRecordingStream) RecvMsg(any) error { return nil } + +func ackMsg() *proto.SyncMappingsRequest { + return &proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + } +} + +func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 7 // 3 + 3 + 1 → 3 batches, 3 acks (one per batch, including final) + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()}, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches") + + assert.Len(t, stream.sent[0].Mapping, 3) + assert.False(t, stream.sent[0].InitialSyncComplete) + + assert.Len(t, stream.sent[1].Mapping, 3) + assert.False(t, stream.sent[1].InitialSyncComplete) + + assert.Len(t, stream.sent[2].Mapping, 1) + assert.True(t, stream.sent[2].InitialSyncComplete) + + // All 3 acks consumed — including the final batch. + assert.Equal(t, 3, stream.recvIdx) +} + +func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 100 + const totalServices = 5 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ackMsg()}, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + require.Len(t, stream.sent, 1) + assert.Len(t, stream.sent[0].Mapping, totalServices) + assert.True(t, stream.sent[0].InitialSyncComplete) + assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed") +} + +func TestSendSnapshotSync_EmptySnapshot(t *testing.T) { + const cluster = "cluster.example.com" + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil) + + s := newSnapshotTestServer(t, 500) + s.serviceManager = mgr + + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ackMsg()}, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete") + assert.Empty(t, stream.sent[0].Mapping) + assert.True(t, stream.sent[0].InitialSyncComplete) + assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed") +} + +func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 4 // 2 batches → 1 ack needed, but we provide none + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + // No acks available — Recv will return error. + stream := &syncRecordingStream{} + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.Error(t, err) + assert.Contains(t, err.Error(), "receive ack") + // First batch should have been sent before the error. + require.Len(t, stream.sent, 1) +} + +func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 4 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + // Send an init message instead of an ack. + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ + {Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}}, + }, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected ack") +} + +func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) { + // Verify batches are sent strictly sequentially — batch N+1 is not sent + // until the ack for batch N is received, including the final batch. + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 6 // 3 batches, 3 acks + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + var mu sync.Mutex + var events []string + + // Build a stream that logs send/recv events so we can verify ordering. + ackCh := make(chan struct{}, 3) + stream := &orderTrackingStream{ + mu: &mu, + events: &events, + ackCh: ackCh, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + // Feed acks asynchronously after a short delay to simulate real proxy. + go func() { + for range 3 { + time.Sleep(10 * time.Millisecond) + ackCh <- struct{}{} + } + }() + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + + // Expected: send, recv-ack, send, recv-ack, send, recv-ack. + require.Len(t, events, 6) + assert.Equal(t, "send", events[0]) + assert.Equal(t, "recv", events[1]) + assert.Equal(t, "send", events[2]) + assert.Equal(t, "recv", events[3]) + assert.Equal(t, "send", events[4]) + assert.Equal(t, "recv", events[5]) +} + +// orderTrackingStream logs "send" and "recv" events and blocks Recv until +// an ack is signaled via ackCh. +type orderTrackingStream struct { + grpc.ServerStream + mu *sync.Mutex + events *[]string + ackCh chan struct{} +} + +func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error { + s.mu.Lock() + *s.events = append(*s.events, "send") + s.mu.Unlock() + return nil +} + +func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) { + <-s.ackCh + s.mu.Lock() + *s.events = append(*s.events, "recv") + s.mu.Unlock() + return ackMsg(), nil +} + +func (s *orderTrackingStream) Context() context.Context { return context.Background() } +func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil } +func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil } +func (s *orderTrackingStream) SetTrailer(metadata.MD) {} +func (s *orderTrackingStream) SendMsg(any) error { return nil } +func (s *orderTrackingStream) RecvMsg(any) error { return nil } + +func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 4 + const ttl = 100 * time.Millisecond + const ackDelay = 200 * time.Millisecond + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + s.tokenTTL = ttl + + // Build a stream that validates tokens immediately on Send, then + // delays the ack to ensure the next batch's tokens are generated fresh. + var validateErrs []error + ackCh := make(chan struct{}, 2) + stream := &tokenValidatingSyncStream{ + tokenStore: s.tokenStore, + validateErrs: &validateErrs, + ackCh: ackCh, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + go func() { + // Delay first ack so that if tokens were all generated upfront they'd expire. + time.Sleep(ackDelay) + ackCh <- struct{}{} + // Final batch ack — immediate. + ackCh <- struct{}{} + }() + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + require.Empty(t, validateErrs, + "tokens must remain valid: per-batch generation guarantees freshness") +} + +type tokenValidatingSyncStream struct { + grpc.ServerStream + tokenStore *OneTimeTokenStore + validateErrs *[]error + ackCh chan struct{} +} + +func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error { + for _, mapping := range m.Mapping { + if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil { + *s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err)) + } + } + return nil +} + +func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) { + <-s.ackCh + return ackMsg(), nil +} + +func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() } +func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil } +func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil } +func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {} +func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil } +func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil } + +func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) { + stream := &syncRecordingStream{} + conn := &proxyConnection{ + syncStream: stream, + } + + resp := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{ + {Id: "svc-1", AccountId: "acct-1", Domain: "example.com"}, + }, + InitialSyncComplete: true, + } + + err := conn.sendResponse(resp) + require.NoError(t, err) + + require.Len(t, stream.sent, 1) + assert.Len(t, stream.sent[0].Mapping, 1) + assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id) + assert.True(t, stream.sent[0].InitialSyncComplete) +} + +func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) { + stream := &recordingStream{} + conn := &proxyConnection{ + stream: stream, + } + + resp := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{ + {Id: "svc-2", AccountId: "acct-2"}, + }, + } + + err := conn.sendResponse(resp) + require.NoError(t, err) + + require.Len(t, stream.messages, 1) + assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id) +} diff --git a/proxy/server.go b/proxy/server.go index 78149b128..4c3c229ab 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -32,9 +32,11 @@ import ( "go.opentelemetry.io/otel/sdk/metric" "golang.org/x/exp/maps" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + grpcstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/timestamppb" @@ -1145,6 +1147,10 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr Clock: backoff.SystemClock, } + // syncSupported tracks whether management supports SyncMappings. + // Starts true; set to false on the first Unimplemented error so + // subsequent retries skip straight to GetMappingUpdate. + syncSupported := true initialSyncDone := false operation := func() error { @@ -1156,41 +1162,25 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr s.healthChecker.SetManagementConnected(false) } - supportsCrowdSec := s.crowdsecRegistry.Available() - privateCapability := s.Private - // Always true: this build enforces ProxyMapping.private via the auth middleware. - supportsPrivateService := true - mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ - ProxyId: s.ID, - Version: s.Version, - StartedAt: timestamppb.New(s.startTime), - Address: s.ProxyURL, - Capabilities: &proto.ProxyCapabilities{ - SupportsCustomPorts: &s.SupportsCustomPorts, - RequireSubdomain: &s.RequireSubdomain, - SupportsCrowdsec: &supportsCrowdSec, - Private: &privateCapability, - SupportsPrivateService: &supportsPrivateService, - }, - }) - if err != nil { - return fmt.Errorf("create mapping stream: %w", err) + var streamErr error + if syncSupported { + streamErr = s.trySyncMappings(ctx, client, &initialSyncDone) + if isSyncUnimplemented(streamErr) { + syncSupported = false + s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate") + streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone) + } + } else { + streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone) } - if s.healthChecker != nil { - s.healthChecker.SetManagementConnected(true) - } - s.Logger.Debug("management mapping stream established") - - // Stream established — reset backoff so the next failure retries quickly. - bo.Reset() - - streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone, time.Now()) - if s.healthChecker != nil { s.healthChecker.SetManagementConnected(false) } + // Stream established — reset backoff so the next failure retries quickly. + bo.Reset() + if streamErr == nil { return fmt.Errorf("stream closed by server") } @@ -1207,6 +1197,134 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } } +func (s *Server) proxyCapabilities() *proto.ProxyCapabilities { + supportsCrowdSec := s.crowdsecRegistry.Available() + privateCapability := s.Private + // Always true: this build enforces ProxyMapping.private via the auth middleware. + supportsPrivateService := true + return &proto.ProxyCapabilities{ + SupportsCustomPorts: &s.SupportsCustomPorts, + RequireSubdomain: &s.RequireSubdomain, + SupportsCrowdsec: &supportsCrowdSec, + Private: &privateCapability, + SupportsPrivateService: &supportsPrivateService, + } +} + +func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error { + connectTime := time.Now() + mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: s.ID, + Version: s.Version, + StartedAt: timestamppb.New(s.startTime), + Address: s.ProxyURL, + Capabilities: s.proxyCapabilities(), + }) + if err != nil { + return fmt.Errorf("create mapping stream: %w", err) + } + + if s.healthChecker != nil { + s.healthChecker.SetManagementConnected(true) + } + s.Logger.Debug("management mapping stream established (GetMappingUpdate)") + + return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime) +} + +func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error { + connectTime := time.Now() + stream, err := client.SyncMappings(ctx) + if err != nil { + return fmt.Errorf("create sync stream: %w", err) + } + + if err := stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: s.ID, + Version: s.Version, + StartedAt: timestamppb.New(s.startTime), + Address: s.ProxyURL, + Capabilities: s.proxyCapabilities(), + }, + }, + }); err != nil { + return fmt.Errorf("send sync init: %w", err) + } + + if s.healthChecker != nil { + s.healthChecker.SetManagementConnected(true) + } + s.Logger.Debug("management mapping stream established (SyncMappings)") + + return s.handleSyncMappingsStream(ctx, stream, initialSyncDone, connectTime) +} + +func isSyncUnimplemented(err error) bool { + if err == nil { + return false + } + st, ok := grpcstatus.FromError(err) + return ok && st.Code() == codes.Unimplemented +} + +// handleSyncMappingsStream consumes batches from a bidirectional SyncMappings +// stream, sending an ack after each batch is fully processed. Management waits +// for the ack before sending the next batch, providing application-level +// back-pressure. +func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, _ time.Time) error { + select { + case <-s.routerReady: + case <-ctx.Done(): + return ctx.Err() + } + + var snapshotIDs map[types.ServiceID]struct{} + if !*initialSyncDone { + snapshotIDs = make(map[types.ServiceID]struct{}) + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + msg, err := stream.Recv() + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return fmt.Errorf("receive msg: %w", err) + } + s.Logger.Debug("Received mapping update, starting processing") + s.processMappings(ctx, msg.GetMapping()) + s.Logger.Debug("Processing mapping update completed") + + if !*initialSyncDone { + for _, m := range msg.GetMapping() { + snapshotIDs[types.ServiceID(m.GetId())] = struct{}{} + } + if msg.GetInitialSyncComplete() { + s.reconcileSnapshot(ctx, snapshotIDs) + snapshotIDs = nil + if s.healthChecker != nil { + s.healthChecker.SetInitialSyncComplete() + } + *initialSyncDone = true + s.Logger.Info("Initial mapping sync complete") + } + } + + if err := stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }); err != nil { + return fmt.Errorf("send ack: %w", err) + } + } + } +} + func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, _ time.Time) error { select { case <-s.routerReady: diff --git a/proxy/sync_mappings_test.go b/proxy/sync_mappings_test.go new file mode 100644 index 000000000..801587e4c --- /dev/null +++ b/proxy/sync_mappings_test.go @@ -0,0 +1,525 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + grpcstatus "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestIntegration_SyncMappings_HappyPath(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + // Send init. + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "sync-proxy-1", + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + mappingsByID := make(map[string]*proto.ProxyMapping) + for { + msg, err := stream.Recv() + require.NoError(t, err) + for _, m := range msg.GetMapping() { + mappingsByID[m.GetId()] = m + } + + // Ack every batch. + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + assert.Len(t, mappingsByID, 2, "Should receive 2 mappings") + + rp1 := mappingsByID["rp-1"] + require.NotNil(t, rp1) + assert.Equal(t, "app1.test.proxy.io", rp1.GetDomain()) + assert.Equal(t, "test-account-1", rp1.GetAccountId()) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, rp1.GetType()) + assert.NotEmpty(t, rp1.GetAuthToken(), "Should have auth token") + + rp2 := mappingsByID["rp-2"] + require.NotNil(t, rp2) + assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain()) +} + +func TestIntegration_SyncMappings_BackPressure(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + // Add enough services to guarantee multiple batches (default batch size 500). + addServicesToStore(t, setup, 600, "test.proxy.io") + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "sync-proxy-backpressure", + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + // Strategy: receive batch 1, then hold for a significant delay before + // acking. If back-pressure works, batch 2 cannot arrive until after + // the ack is sent — so its receive timestamp must be >= the ack + // timestamp. If management were fire-and-forget, all batches would + // already be buffered in the gRPC transport and batch 2 would arrive + // well before the ack time. + const ackDelay = 300 * time.Millisecond + + type batchEvent struct { + recvAt time.Time + ackAt time.Time + count int + } + var batches []batchEvent + var totalMappings int + + for { + msg, err := stream.Recv() + require.NoError(t, err) + + recvAt := time.Now() + totalMappings += len(msg.GetMapping()) + + // Delay the ack on non-final batches to create a measurable gap. + if !msg.GetInitialSyncComplete() { + time.Sleep(ackDelay) + } + + ackAt := time.Now() + batches = append(batches, batchEvent{ + recvAt: recvAt, + ackAt: ackAt, + count: len(msg.GetMapping()), + }) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + // 2 original + 600 added = 602 services total. + assert.Equal(t, 602, totalMappings, "should receive all 602 mappings") + require.GreaterOrEqual(t, len(batches), 2, "need at least 2 batches to verify back-pressure") + + // For every batch after the first, its receive time must be after the + // previous batch's ack time. This proves management waited for the ack + // before sending the next batch. + for i := 1; i < len(batches); i++ { + prevAckAt := batches[i-1].ackAt + thisRecvAt := batches[i].recvAt + assert.True(t, !thisRecvAt.Before(prevAckAt), + "batch %d received at %v, but batch %d was acked at %v — "+ + "management sent the next batch before receiving the ack", + i, thisRecvAt, i-1, prevAckAt) + } +} + +func TestIntegration_SyncMappings_IncrementalUpdate(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "sync-proxy-incremental", + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + // Drain initial snapshot. + for { + msg, err := stream.Recv() + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + // Now send an incremental update via the management server. + setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "rp-1", + AccountId: "test-account-1", + Domain: "app1.test.proxy.io", + }}, + }) + + // Receive the incremental update on the sync stream. + msg, err := stream.Recv() + require.NoError(t, err) + require.NotEmpty(t, msg.GetMapping()) + assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId()) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.GetMapping()[0].GetType()) +} + +func TestIntegration_SyncMappings_MixedProxyVersions(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Old proxy uses GetMappingUpdate. + legacyStream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "legacy-proxy", + Version: "old-v1", + Address: "test.proxy.io", + }) + require.NoError(t, err) + + var legacyMappings []*proto.ProxyMapping + for { + msg, err := legacyStream.Recv() + require.NoError(t, err) + legacyMappings = append(legacyMappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } + } + + // New proxy uses SyncMappings. + syncStream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = syncStream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "new-proxy", + Version: "new-v2", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + var syncMappings []*proto.ProxyMapping + for { + msg, err := syncStream.Recv() + require.NoError(t, err) + syncMappings = append(syncMappings, msg.GetMapping()...) + + err = syncStream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + // Both should receive the same set of mappings. + assert.Equal(t, len(legacyMappings), len(syncMappings), + "legacy and sync proxies should receive the same number of mappings") + + legacyIDs := make(map[string]bool) + for _, m := range legacyMappings { + legacyIDs[m.GetId()] = true + } + for _, m := range syncMappings { + assert.True(t, legacyIDs[m.GetId()], + "mapping %s should be present in both streams", m.GetId()) + } + + // Both proxies should be connected. + proxies := setup.proxyService.GetConnectedProxies() + assert.Contains(t, proxies, "legacy-proxy") + assert.Contains(t, proxies, "new-proxy") + + // Both should receive incremental updates. + setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "rp-1", + AccountId: "test-account-1", + Domain: "app1.test.proxy.io", + }}, + }) + + // Legacy proxy receives via GetMappingUpdateResponse. + legacyMsg, err := legacyStream.Recv() + require.NoError(t, err) + assert.Equal(t, "rp-1", legacyMsg.GetMapping()[0].GetId()) + + // Sync proxy receives via SyncMappingsResponse. + syncMsg, err := syncStream.Recv() + require.NoError(t, err) + assert.Equal(t, "rp-1", syncMsg.GetMapping()[0].GetId()) +} + +func TestIntegration_SyncMappings_Reconnect(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + proxyID := "sync-proxy-reconnect" + + receiveMappings := func() []*proto.ProxyMapping { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: proxyID, + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + var mappings []*proto.ProxyMapping + for { + msg, err := stream.Recv() + require.NoError(t, err) + mappings = append(mappings, msg.GetMapping()...) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + return mappings + } + + first := receiveMappings() + time.Sleep(100 * time.Millisecond) + second := receiveMappings() + + assert.Equal(t, len(first), len(second), + "should receive same mappings on reconnect") + + firstIDs := make(map[string]bool) + for _, m := range first { + firstIDs[m.GetId()] = true + } + for _, m := range second { + assert.True(t, firstIDs[m.GetId()], + "mapping %s should be present in both connections", m.GetId()) + } +} + +// --- Fallback tests: old management returns Unimplemented --- + +// unimplementedProxyServer embeds UnimplementedProxyServiceServer so +// SyncMappings returns codes.Unimplemented while GetMappingUpdate works. +type unimplementedSyncServer struct { + proto.UnimplementedProxyServiceServer + getMappingCalls atomic.Int32 +} + +func (s *unimplementedSyncServer) GetMappingUpdate(_ *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { + s.getMappingCalls.Add(1) + return stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"}}, + InitialSyncComplete: true, + }) +} + +func TestIntegration_FallbackToGetMappingUpdate(t *testing.T) { + // Start a gRPC server that does NOT implement SyncMappings. + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + srv := &unimplementedSyncServer{} + grpcServer := grpc.NewServer() + proto.RegisterProxyServiceServer(grpcServer, srv) + go func() { _ = grpcServer.Serve(lis) }() + defer grpcServer.GracefulStop() + + conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + // Try SyncMappings — should get Unimplemented. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "test-proxy", + Address: "test.example.com", + }, + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unimplemented, st.Code(), + "unimplemented SyncMappings should return Unimplemented code") + + // isSyncUnimplemented should detect this. + assert.True(t, isSyncUnimplemented(err)) + + // The actual fallback: GetMappingUpdate should work. + legacyStream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "test-proxy", + Address: "test.example.com", + }) + require.NoError(t, err) + + msg, err := legacyStream.Recv() + require.NoError(t, err) + assert.True(t, msg.GetInitialSyncComplete()) + assert.Len(t, msg.GetMapping(), 1) + assert.Equal(t, int32(1), srv.getMappingCalls.Load()) +} + +func TestIsSyncUnimplemented(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil error", nil, false}, + {"non-grpc error", errors.New("random"), false}, + {"grpc internal", grpcstatus.Error(codes.Internal, "fail"), false}, + {"grpc unavailable", grpcstatus.Error(codes.Unavailable, "fail"), false}, + {"grpc unimplemented", grpcstatus.Error(codes.Unimplemented, "method not found"), true}, + { + "wrapped unimplemented", + fmt.Errorf("create sync stream: %w", grpcstatus.Error(codes.Unimplemented, "nope")), + // grpc/status.FromError unwraps in recent versions of grpc-go. + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isSyncUnimplemented(tt.err)) + }) + } +} + +// addServicesToStore adds n extra services to the test store for the given cluster. +func addServicesToStore(t *testing.T, setup *integrationTestSetup, n int, cluster string) { + t.Helper() + ctx := context.Background() + for i := 0; i < n; i++ { + svc := &service.Service{ + ID: fmt.Sprintf("extra-svc-%d", i), + AccountID: "test-account-1", + Name: fmt.Sprintf("Extra Service %d", i), + Domain: fmt.Sprintf("extra-%d.test.proxy.io", i), + ProxyCluster: cluster, + Enabled: true, + Targets: []*service.Target{{ + Path: strPtr("/"), + Host: fmt.Sprintf("10.0.1.%d", i%256), + Port: 8080, + Protocol: "http", + TargetId: fmt.Sprintf("peer-extra-%d", i), + TargetType: "peer", + Enabled: true, + }}, + } + require.NoError(t, setup.store.CreateService(ctx, svc)) + } +} diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index a512214e5..67f9a10fb 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -2291,6 +2291,269 @@ func (x *ValidateTunnelPeerResponse) GetPeerGroupNames() []string { return nil } +// SyncMappingsRequest is sent by the proxy on the bidirectional SyncMappings +// stream. The first message MUST be an init; all subsequent messages MUST be +// acks. +type SyncMappingsRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Msg: + // + // *SyncMappingsRequest_Init + // *SyncMappingsRequest_Ack + Msg isSyncMappingsRequest_Msg `protobuf_oneof:"msg"` +} + +func (x *SyncMappingsRequest) Reset() { + *x = SyncMappingsRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[28] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsRequest) ProtoMessage() {} + +func (x *SyncMappingsRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[28] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsRequest.ProtoReflect.Descriptor instead. +func (*SyncMappingsRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{28} +} + +func (m *SyncMappingsRequest) GetMsg() isSyncMappingsRequest_Msg { + if m != nil { + return m.Msg + } + return nil +} + +func (x *SyncMappingsRequest) GetInit() *SyncMappingsInit { + if x, ok := x.GetMsg().(*SyncMappingsRequest_Init); ok { + return x.Init + } + return nil +} + +func (x *SyncMappingsRequest) GetAck() *SyncMappingsAck { + if x, ok := x.GetMsg().(*SyncMappingsRequest_Ack); ok { + return x.Ack + } + return nil +} + +type isSyncMappingsRequest_Msg interface { + isSyncMappingsRequest_Msg() +} + +type SyncMappingsRequest_Init struct { + Init *SyncMappingsInit `protobuf:"bytes,1,opt,name=init,proto3,oneof"` +} + +type SyncMappingsRequest_Ack struct { + Ack *SyncMappingsAck `protobuf:"bytes,2,opt,name=ack,proto3,oneof"` +} + +func (*SyncMappingsRequest_Init) isSyncMappingsRequest_Msg() {} + +func (*SyncMappingsRequest_Ack) isSyncMappingsRequest_Msg() {} + +// SyncMappingsInit is the first message on the stream, carrying the same +// identification fields as GetMappingUpdateRequest. +type SyncMappingsInit struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` +} + +func (x *SyncMappingsInit) Reset() { + *x = SyncMappingsInit{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[29] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsInit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsInit) ProtoMessage() {} + +func (x *SyncMappingsInit) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[29] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsInit.ProtoReflect.Descriptor instead. +func (*SyncMappingsInit) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{29} +} + +func (x *SyncMappingsInit) GetProxyId() string { + if x != nil { + return x.ProxyId + } + return "" +} + +func (x *SyncMappingsInit) GetVersion() string { + if x != nil { + return x.Version + } + return "" +} + +func (x *SyncMappingsInit) GetStartedAt() *timestamppb.Timestamp { + if x != nil { + return x.StartedAt + } + return nil +} + +func (x *SyncMappingsInit) GetAddress() string { + if x != nil { + return x.Address + } + return "" +} + +func (x *SyncMappingsInit) GetCapabilities() *ProxyCapabilities { + if x != nil { + return x.Capabilities + } + return nil +} + +// SyncMappingsAck is sent by the proxy after it has fully processed a batch. +// Management waits for this before sending the next batch. +type SyncMappingsAck struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SyncMappingsAck) Reset() { + *x = SyncMappingsAck{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[30] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsAck) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsAck) ProtoMessage() {} + +func (x *SyncMappingsAck) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[30] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsAck.ProtoReflect.Descriptor instead. +func (*SyncMappingsAck) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{30} +} + +// SyncMappingsResponse is a batch of mappings sent by management. +// Identical semantics to GetMappingUpdateResponse. +type SyncMappingsResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Mapping []*ProxyMapping `protobuf:"bytes,1,rep,name=mapping,proto3" json:"mapping,omitempty"` + // initial_sync_complete is set on the last message of the initial snapshot. + InitialSyncComplete bool `protobuf:"varint,2,opt,name=initial_sync_complete,json=initialSyncComplete,proto3" json:"initial_sync_complete,omitempty"` +} + +func (x *SyncMappingsResponse) Reset() { + *x = SyncMappingsResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsResponse) ProtoMessage() {} + +func (x *SyncMappingsResponse) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[31] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsResponse.ProtoReflect.Descriptor instead. +func (*SyncMappingsResponse) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{31} +} + +func (x *SyncMappingsResponse) GetMapping() []*ProxyMapping { + if x != nil { + return x.Mapping + } + return nil +} + +func (x *SyncMappingsResponse) GetInitialSyncComplete() bool { + if x != nil { + return x.InitialSyncComplete + } + return false +} + var File_proxy_service_proto protoreflect.FileDescriptor var file_proxy_service_proto_rawDesc = []byte{ @@ -2628,37 +2891,74 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, - 0x73, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, - 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, - 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, - 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, - 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, - 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, - 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, - 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, - 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, - 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, - 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, - 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, - 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, - 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, - 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, - 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, - 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, - 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, - 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, - 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, - 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, - 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, 0xe1, 0x05, 0x0a, 0x0c, 0x50, - 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, - 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, - 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, + 0x73, 0x22, 0x81, 0x01, 0x0a, 0x13, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x32, 0x0a, 0x04, 0x69, 0x6e, 0x69, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x73, 0x49, 0x6e, 0x69, 0x74, 0x48, 0x00, 0x52, 0x04, 0x69, 0x6e, 0x69, 0x74, 0x12, 0x2f, 0x0a, + 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x48, 0x00, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x42, 0x05, + 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xdf, 0x01, 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, + 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, + 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, + 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, + 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, + 0x74, 0x69, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, + 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, + 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, 0x11, 0x0a, 0x0f, 0x53, 0x79, 0x6e, 0x63, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x22, 0x7e, 0x0a, 0x14, 0x53, 0x79, + 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, + 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, + 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, + 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, + 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, + 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, + 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, + 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, + 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, + 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, + 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, + 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, + 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, + 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, + 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, + 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, + 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, + 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, + 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, + 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, + 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, + 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, + 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x05, 0x32, 0xb8, 0x06, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x55, 0x0a, 0x0c, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, + 0x70, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, @@ -2714,7 +3014,7 @@ func file_proxy_service_proto_rawDescGZIP() []byte { } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 30) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 34) var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (PathRewriteMode)(0), // 1: management.PathRewriteMode @@ -2747,19 +3047,23 @@ var file_proxy_service_proto_goTypes = []interface{}{ (*ValidateSessionResponse)(nil), // 28: management.ValidateSessionResponse (*ValidateTunnelPeerRequest)(nil), // 29: management.ValidateTunnelPeerRequest (*ValidateTunnelPeerResponse)(nil), // 30: management.ValidateTunnelPeerResponse - nil, // 31: management.PathTargetOptions.CustomHeadersEntry - nil, // 32: management.AccessLog.MetadataEntry - (*timestamppb.Timestamp)(nil), // 33: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 34: google.protobuf.Duration + (*SyncMappingsRequest)(nil), // 31: management.SyncMappingsRequest + (*SyncMappingsInit)(nil), // 32: management.SyncMappingsInit + (*SyncMappingsAck)(nil), // 33: management.SyncMappingsAck + (*SyncMappingsResponse)(nil), // 34: management.SyncMappingsResponse + nil, // 35: management.PathTargetOptions.CustomHeadersEntry + nil, // 36: management.AccessLog.MetadataEntry + (*timestamppb.Timestamp)(nil), // 37: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 38: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 33, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 37, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities 11, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 34, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 38, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 31, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 34, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 35, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 38, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions 8, // 8: management.Authentication.header_auths:type_name -> management.HeaderAuth 0, // 9: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType @@ -2767,34 +3071,41 @@ var file_proxy_service_proto_depIdxs = []int32{ 9, // 11: management.ProxyMapping.auth:type_name -> management.Authentication 10, // 12: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions 14, // 13: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 33, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 32, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry + 37, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 36, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry 17, // 16: management.AuthenticateRequest.password:type_name -> management.PasswordRequest 18, // 17: management.AuthenticateRequest.pin:type_name -> management.PinRequest 16, // 18: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest 2, // 19: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus 21, // 20: management.SendStatusUpdateRequest.inbound_listener:type_name -> management.ProxyInboundListener - 4, // 21: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 12, // 22: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 15, // 23: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 20, // 24: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 23, // 25: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 25, // 26: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 27, // 27: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 29, // 28: management.ProxyService.ValidateTunnelPeer:input_type -> management.ValidateTunnelPeerRequest - 5, // 29: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 13, // 30: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 19, // 31: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 22, // 32: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 24, // 33: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 26, // 34: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 28, // 35: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 30, // 36: management.ProxyService.ValidateTunnelPeer:output_type -> management.ValidateTunnelPeerResponse - 29, // [29:37] is the sub-list for method output_type - 21, // [21:29] is the sub-list for method input_type - 21, // [21:21] is the sub-list for extension type_name - 21, // [21:21] is the sub-list for extension extendee - 0, // [0:21] is the sub-list for field type_name + 32, // 21: management.SyncMappingsRequest.init:type_name -> management.SyncMappingsInit + 33, // 22: management.SyncMappingsRequest.ack:type_name -> management.SyncMappingsAck + 37, // 23: management.SyncMappingsInit.started_at:type_name -> google.protobuf.Timestamp + 3, // 24: management.SyncMappingsInit.capabilities:type_name -> management.ProxyCapabilities + 11, // 25: management.SyncMappingsResponse.mapping:type_name -> management.ProxyMapping + 4, // 26: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 31, // 27: management.ProxyService.SyncMappings:input_type -> management.SyncMappingsRequest + 12, // 28: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 15, // 29: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 20, // 30: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 23, // 31: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 25, // 32: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 27, // 33: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 29, // 34: management.ProxyService.ValidateTunnelPeer:input_type -> management.ValidateTunnelPeerRequest + 5, // 35: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 34, // 36: management.ProxyService.SyncMappings:output_type -> management.SyncMappingsResponse + 13, // 37: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 19, // 38: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 22, // 39: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 24, // 40: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 26, // 41: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 28, // 42: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 30, // 43: management.ProxyService.ValidateTunnelPeer:output_type -> management.ValidateTunnelPeerResponse + 35, // [35:44] is the sub-list for method output_type + 26, // [26:35] is the sub-list for method input_type + 26, // [26:26] is the sub-list for extension type_name + 26, // [26:26] is the sub-list for extension extendee + 0, // [0:26] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -3139,6 +3450,54 @@ func file_proxy_service_proto_init() { return nil } } + file_proxy_service_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsInit); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsAck); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} file_proxy_service_proto_msgTypes[12].OneofWrappers = []interface{}{ @@ -3148,13 +3507,17 @@ func file_proxy_service_proto_init() { } file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} file_proxy_service_proto_msgTypes[21].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[28].OneofWrappers = []interface{}{ + (*SyncMappingsRequest_Init)(nil), + (*SyncMappingsRequest_Ack)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 3, - NumMessages: 30, + NumMessages: 34, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index e0763e890..eb5057fc5 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -12,6 +12,15 @@ import "google/protobuf/timestamp.proto"; service ProxyService { rpc GetMappingUpdate(GetMappingUpdateRequest) returns (stream GetMappingUpdateResponse); + // SyncMappings is a bidirectional stream that replaces GetMappingUpdate for + // new proxies. The proxy sends an initial SyncMappingsRequest to start the + // stream and then sends an ack after each batch is fully processed. + // Management waits for the ack before sending the next batch, providing + // application-level back-pressure during large initial syncs. + // Old proxies continue using GetMappingUpdate; old management servers + // return Unimplemented for this RPC and proxies fall back. + rpc SyncMappings(stream SyncMappingsRequest) returns (stream SyncMappingsResponse); + rpc SendAccessLog(SendAccessLogRequest) returns (SendAccessLogResponse); rpc Authenticate(AuthenticateRequest) returns (AuthenticateResponse); @@ -335,3 +344,35 @@ message ValidateTunnelPeerResponse { repeated string peer_group_names = 7; } +// SyncMappingsRequest is sent by the proxy on the bidirectional SyncMappings +// stream. The first message MUST be an init; all subsequent messages MUST be +// acks. +message SyncMappingsRequest { + oneof msg { + SyncMappingsInit init = 1; + SyncMappingsAck ack = 2; + } +} + +// SyncMappingsInit is the first message on the stream, carrying the same +// identification fields as GetMappingUpdateRequest. +message SyncMappingsInit { + string proxy_id = 1; + string version = 2; + google.protobuf.Timestamp started_at = 3; + string address = 4; + ProxyCapabilities capabilities = 5; +} + +// SyncMappingsAck is sent by the proxy after it has fully processed a batch. +// Management waits for this before sending the next batch. +message SyncMappingsAck {} + +// SyncMappingsResponse is a batch of mappings sent by management. +// Identical semantics to GetMappingUpdateResponse. +message SyncMappingsResponse { + repeated ProxyMapping mapping = 1; + // initial_sync_complete is set on the last message of the initial snapshot. + bool initial_sync_complete = 2; +} + diff --git a/shared/management/proto/proxy_service_grpc.pb.go b/shared/management/proto/proxy_service_grpc.pb.go index 9f28aaaaa..30eabc296 100644 --- a/shared/management/proto/proxy_service_grpc.pb.go +++ b/shared/management/proto/proxy_service_grpc.pb.go @@ -19,6 +19,14 @@ const _ = grpc.SupportPackageIsVersion7 // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type ProxyServiceClient interface { GetMappingUpdate(ctx context.Context, in *GetMappingUpdateRequest, opts ...grpc.CallOption) (ProxyService_GetMappingUpdateClient, error) + // SyncMappings is a bidirectional stream that replaces GetMappingUpdate for + // new proxies. The proxy sends an initial SyncMappingsRequest to start the + // stream and then sends an ack after each batch is fully processed. + // Management waits for the ack before sending the next batch, providing + // application-level back-pressure during large initial syncs. + // Old proxies continue using GetMappingUpdate; old management servers + // return Unimplemented for this RPC and proxies fall back. + SyncMappings(ctx context.Context, opts ...grpc.CallOption) (ProxyService_SyncMappingsClient, error) SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error) Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*AuthenticateResponse, error) SendStatusUpdate(ctx context.Context, in *SendStatusUpdateRequest, opts ...grpc.CallOption) (*SendStatusUpdateResponse, error) @@ -77,6 +85,37 @@ func (x *proxyServiceGetMappingUpdateClient) Recv() (*GetMappingUpdateResponse, return m, nil } +func (c *proxyServiceClient) SyncMappings(ctx context.Context, opts ...grpc.CallOption) (ProxyService_SyncMappingsClient, error) { + stream, err := c.cc.NewStream(ctx, &ProxyService_ServiceDesc.Streams[1], "/management.ProxyService/SyncMappings", opts...) + if err != nil { + return nil, err + } + x := &proxyServiceSyncMappingsClient{stream} + return x, nil +} + +type ProxyService_SyncMappingsClient interface { + Send(*SyncMappingsRequest) error + Recv() (*SyncMappingsResponse, error) + grpc.ClientStream +} + +type proxyServiceSyncMappingsClient struct { + grpc.ClientStream +} + +func (x *proxyServiceSyncMappingsClient) Send(m *SyncMappingsRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *proxyServiceSyncMappingsClient) Recv() (*SyncMappingsResponse, error) { + m := new(SyncMappingsResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func (c *proxyServiceClient) SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error) { out := new(SendAccessLogResponse) err := c.cc.Invoke(ctx, "/management.ProxyService/SendAccessLog", in, out, opts...) @@ -145,6 +184,14 @@ func (c *proxyServiceClient) ValidateTunnelPeer(ctx context.Context, in *Validat // for forward compatibility type ProxyServiceServer interface { GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error + // SyncMappings is a bidirectional stream that replaces GetMappingUpdate for + // new proxies. The proxy sends an initial SyncMappingsRequest to start the + // stream and then sends an ack after each batch is fully processed. + // Management waits for the ack before sending the next batch, providing + // application-level back-pressure during large initial syncs. + // Old proxies continue using GetMappingUpdate; old management servers + // return Unimplemented for this RPC and proxies fall back. + SyncMappings(ProxyService_SyncMappingsServer) error SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error) Authenticate(context.Context, *AuthenticateRequest) (*AuthenticateResponse, error) SendStatusUpdate(context.Context, *SendStatusUpdateRequest) (*SendStatusUpdateResponse, error) @@ -171,6 +218,9 @@ type UnimplementedProxyServiceServer struct { func (UnimplementedProxyServiceServer) GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error { return status.Errorf(codes.Unimplemented, "method GetMappingUpdate not implemented") } +func (UnimplementedProxyServiceServer) SyncMappings(ProxyService_SyncMappingsServer) error { + return status.Errorf(codes.Unimplemented, "method SyncMappings not implemented") +} func (UnimplementedProxyServiceServer) SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SendAccessLog not implemented") } @@ -226,6 +276,32 @@ func (x *proxyServiceGetMappingUpdateServer) Send(m *GetMappingUpdateResponse) e return x.ServerStream.SendMsg(m) } +func _ProxyService_SyncMappings_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ProxyServiceServer).SyncMappings(&proxyServiceSyncMappingsServer{stream}) +} + +type ProxyService_SyncMappingsServer interface { + Send(*SyncMappingsResponse) error + Recv() (*SyncMappingsRequest, error) + grpc.ServerStream +} + +type proxyServiceSyncMappingsServer struct { + grpc.ServerStream +} + +func (x *proxyServiceSyncMappingsServer) Send(m *SyncMappingsResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *proxyServiceSyncMappingsServer) Recv() (*SyncMappingsRequest, error) { + m := new(SyncMappingsRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func _ProxyService_SendAccessLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(SendAccessLogRequest) if err := dec(in); err != nil { @@ -394,6 +470,12 @@ var ProxyService_ServiceDesc = grpc.ServiceDesc{ Handler: _ProxyService_GetMappingUpdate_Handler, ServerStreams: true, }, + { + StreamName: "SyncMappings", + Handler: _ProxyService_SyncMappings_Handler, + ServerStreams: true, + ClientStreams: true, + }, }, Metadata: "proxy_service.proto", }