rename reverse proxy to services

This commit is contained in:
pascal
2026-02-11 21:39:51 +01:00
parent ebb1f4007d
commit 08ab1e3478
32 changed files with 1125 additions and 1471 deletions

View File

@@ -176,39 +176,39 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy.
// Only reverse proxies matching the proxy's cluster address are sent.
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
reverseProxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
if err != nil {
return fmt.Errorf("get reverse proxies from store: %w", err)
return fmt.Errorf("get services from store: %w", err)
}
proxyClusterAddr := extractClusterAddr(conn.address)
for _, rp := range reverseProxies {
if !rp.Enabled {
for _, service := range services {
if !service.Enabled {
continue
}
if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr {
if service.ProxyCluster != "" && proxyClusterAddr != "" && service.ProxyCluster != proxyClusterAddr {
continue
}
// Generate one-time authentication token for each proxy in the snapshot
// Generate one-time authentication token for each service in the snapshot
// Tokens are not persistent on the proxy, so we need to generate new ones on reconnection
token, err := s.tokenStore.GenerateToken(rp.AccountID, rp.ID, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
if err != nil {
log.WithFields(log.Fields{
"proxy": rp.Name,
"account": rp.AccountID,
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("Failed to generate auth token for snapshot")
continue
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
rp.ToProtoMapping(
service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
token,
s.GetOIDCValidationConfig(),
@@ -260,10 +260,10 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
accessLog := req.GetLog()
fields := log.Fields{
"reverse_proxy_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
"host": accessLog.GetHost(),
"source_ip": accessLog.GetSourceIp(),
"service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
"host": accessLog.GetHost(),
"source_ip": accessLog.GetSourceIp(),
}
if mechanism := accessLog.GetAuthMechanism(); mechanism != "" {
fields["auth_mechanism"] = mechanism
@@ -292,18 +292,18 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
return &proto.SendAccessLogResponse{}, nil
}
// SendReverseProxyUpdate broadcasts a reverse proxy update to all connected proxies.
// Management should call this when reverse proxies are created/updated/removed
func (s *ProxyServiceServer) SendReverseProxyUpdate(update *proto.ProxyMapping) {
// Send it to all connected proxies
log.Debugf("Broadcasting reverse proxy update to all connected proxies")
// SendServiceUpdate broadcasts a service update to all connected proxy servers.
// Management should call this when services are created/updated/removed
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
// Send it to all connected proxy servers
log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
select {
case conn.sendChan <- update:
log.Debugf("Sent reverse proxy update with id %s to proxy %s", update.Id, conn.proxyID)
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
default:
log.Warnf("Failed to send reverse proxy update to proxy %s (channel full)", conn.proxyID)
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
}
return true
})
@@ -366,11 +366,11 @@ func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
}
}
// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility).
func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
if clusterAddr == "" {
s.SendReverseProxyUpdate(update)
s.SendServiceUpdate(update)
return
}
@@ -380,16 +380,16 @@ func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.Proxy
return
}
log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr)
log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxyID := key.(string)
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
select {
case conn.sendChan <- update:
log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
return true
@@ -424,10 +424,10 @@ func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
proxy, err := s.reverseProxyManager.GetProxyByID(ctx, req.GetAccountId(), req.GetId())
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
// TODO: log the error
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "failed to get service from store: %v", err)
}
var authenticated bool
@@ -435,7 +435,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
var method proxyauth.Method
switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin:
auth := proxy.Auth.PinAuth
auth := service.Auth.PinAuth
if auth == nil || !auth.Enabled {
// TODO: log
// Break here and use the default authenticated == false.
@@ -454,7 +454,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
userId = "pin-user"
method = proxyauth.MethodPIN
case *proto.AuthenticateRequest_Password:
auth := proxy.Auth.PasswordAuth
auth := service.Auth.PasswordAuth
if auth == nil || !auth.Enabled {
// TODO: log
// Break here and use the default authenticated == false.
@@ -475,11 +475,11 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}
var token string
if authenticated && proxy.SessionPrivateKey != "" {
if authenticated && service.SessionPrivateKey != "" {
token, err = sessionkey.SignToken(
proxy.SessionPrivateKey,
service.SessionPrivateKey,
userId,
proxy.Domain,
service.Domain,
method,
proxyauth.DefaultSessionExpiry,
)
@@ -498,45 +498,45 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
// SendStatusUpdate handles status updates from proxy clients
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
accountID := req.GetAccountId()
reverseProxyID := req.GetReverseProxyId()
serviceID := req.GetServiceId()
protoStatus := req.GetStatus()
certificateIssued := req.GetCertificateIssued()
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"service_id": serviceID,
"account_id": accountID,
"status": protoStatus,
"certificate_issued": certificateIssued,
"error_message": req.GetErrorMessage(),
}).Debug("Status update from proxy")
}).Debug("Status update from proxy server")
if reverseProxyID == "" || accountID == "" {
return nil, status.Errorf(codes.InvalidArgument, "reverse_proxy_id and account_id are required")
if serviceID == "" || accountID == "" {
return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required")
}
if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, reverseProxyID); err != nil {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "failed to update certificate timestamp: %v", err)
}
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"service_id": serviceID,
"account_id": accountID,
}).Info("Certificate issued timestamp updated")
}
internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, reverseProxyID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set proxy status")
return nil, status.Errorf(codes.Internal, "failed to update proxy status: %v", err)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set service status")
return nil, status.Errorf(codes.Internal, "failed to update service status: %v", err)
}
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"status": internalStatus,
}).Info("Proxy status updated")
"service_id": serviceID,
"account_id": accountID,
"status": internalStatus,
}).Info("Service status updated")
return &proto.SendStatusUpdateResponse{}, nil
}
@@ -563,30 +563,30 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStat
// CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
reverseProxyID := req.GetReverseProxyId()
serviceID := req.GetServiceId()
accountID := req.GetAccountId()
token := req.GetToken()
cluster := req.GetCluster()
key := req.WireguardPublicKey
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"cluster": cluster,
"service_id": serviceID,
"account_id": accountID,
"cluster": cluster,
}).Debug("CreateProxyPeer request received")
if reverseProxyID == "" || accountID == "" || token == "" {
if serviceID == "" || accountID == "" || token == "" {
log.Warn("CreateProxyPeer: missing required fields")
return &proto.CreateProxyPeerResponse{
Success: false,
ErrorMessage: strPtr("missing required fields: reverse_proxy_id, account_id, and token are required"),
ErrorMessage: strPtr("missing required fields: service_id, account_id, and token are required"),
}, nil
}
if err := s.tokenStore.ValidateAndConsume(token, accountID, reverseProxyID); err != nil {
if err := s.tokenStore.ValidateAndConsume(token, accountID, serviceID); err != nil {
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"service_id": serviceID,
"account_id": accountID,
}).WithError(err).Warn("CreateProxyPeer: token validation failed")
return &proto.CreateProxyPeerResponse{
Success: false,
@@ -597,8 +597,8 @@ func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.Cre
err := s.peersManager.CreateProxyPeer(ctx, accountID, key, cluster)
if err != nil {
log.WithFields(log.Fields{
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"service_id": serviceID,
"account_id": accountID,
}).WithError(err).Error("CreateProxyPeer: failed to create proxy peer")
return &proto.CreateProxyPeerResponse{
Success: false,
@@ -622,22 +622,22 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
// TODO: log
return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err)
}
// Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection.
proxies, err := s.reverseProxyManager.GetAccountReverseProxies(ctx, req.GetAccountId())
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "failed to get services from store: %v", err)
}
var found bool
for _, proxy := range proxies {
if proxy.Domain == redirectURL.Hostname() {
for _, service := range services {
if service.Domain == redirectURL.Hostname() {
found = true
break
}
}
if !found {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "reverse proxy not found in store")
return nil, status.Errorf(codes.FailedPrecondition, "service not found in store")
}
provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer)
@@ -728,29 +728,29 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the proxy by domain to get its signing key
proxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
// Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
if err != nil {
return "", fmt.Errorf("get reverse proxies: %w", err)
return "", fmt.Errorf("get services: %w", err)
}
var proxy *reverseproxy.ReverseProxy
for _, p := range proxies {
if p.Domain == domain {
proxy = p
var service *reverseproxy.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
break
}
}
if proxy == nil {
return "", fmt.Errorf("reverse proxy not found for domain: %s", domain)
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
}
if proxy.SessionPrivateKey == "" {
if service.SessionPrivateKey == "" {
return "", fmt.Errorf("no session key configured for domain: %s", domain)
}
return sessionkey.SignToken(
proxy.SessionPrivateKey,
service.SessionPrivateKey,
userID,
domain,
method,
@@ -758,8 +758,8 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
)
}
// ValidateUserGroupAccess checks if a user has access to a reverse proxy.
// It looks up the proxy within the user's account only, then optionally checks
// ValidateUserGroupAccess checks if a user has access to a service.
// It looks up the service within the user's account only, then optionally checks
// group membership if BearerAuth with DistributionGroups is configured.
func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain, userID string) error {
user, err := s.usersManager.GetUser(ctx, userID)
@@ -767,16 +767,16 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user not found: %s", userID)
}
proxy, err := s.getAccountProxyByDomain(ctx, user.AccountID, domain)
service, err := s.getAccountServiceByDomain(ctx, user.AccountID, domain)
if err != nil {
return err
}
if proxy.Auth.BearerAuth == nil || !proxy.Auth.BearerAuth.Enabled {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}
allowedGroups := proxy.Auth.BearerAuth.DistributionGroups
allowedGroups := service.Auth.BearerAuth.DistributionGroups
if len(allowedGroups) == 0 {
return nil
}
@@ -800,19 +800,19 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
}
func (s *ProxyServiceServer) getAccountProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) {
proxies, err := s.reverseProxyManager.GetAccountReverseProxies(ctx, accountID)
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account reverse proxies: %w", err)
return nil, fmt.Errorf("get account services: %w", err)
}
for _, proxy := range proxies {
if proxy.Domain == domain {
return proxy, nil
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("reverse proxy not found for domain %s in account %s", domain, accountID)
return nil, fmt.Errorf("service not found for domain %s in account %s", domain, accountID)
}
// ValidateSession validates a session token and checks if the user has access to the domain.
@@ -827,19 +827,19 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
proxy, err := s.getProxyByDomain(ctx, domain)
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Debug("ValidateSession: proxy not found")
}).Debug("ValidateSession: service not found")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "proxy_not_found",
DeniedReason: "service_not_found",
}, nil
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(proxy.SessionPublicKey)
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -847,7 +847,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}).Error("ValidateSession: decode public key")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "invalid_proxy_config",
DeniedReason: "invalid_service_config",
}, nil
}
@@ -876,12 +876,12 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
if user.AccountID != proxy.AccountID {
if user.AccountID != service.AccountID {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"user_account": user.AccountID,
"proxy_account": proxy.AccountID,
"domain": domain,
"user_id": userID,
"user_account": user.AccountID,
"service_account": service.AccountID,
}).Debug("ValidateSession: user account mismatch")
return &proto.ValidateSessionResponse{
Valid: false,
@@ -889,7 +889,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
if err := s.checkGroupAccess(proxy, user); err != nil {
if err := s.checkGroupAccess(service, user); err != nil {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
@@ -916,27 +916,27 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
func (s *ProxyServiceServer) getProxyByDomain(ctx context.Context, domain string) (*reverseproxy.ReverseProxy, error) {
proxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get reverse proxies: %w", err)
return nil, fmt.Errorf("get services: %w", err)
}
for _, proxy := range proxies {
if proxy.Domain == domain {
return proxy, nil
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("reverse proxy not found for domain: %s", domain)
return nil, fmt.Errorf("service not found for domain: %s", domain)
}
func (s *ProxyServiceServer) checkGroupAccess(proxy *reverseproxy.ReverseProxy, user *types.User) error {
if proxy.Auth.BearerAuth == nil || !proxy.Auth.BearerAuth.Enabled {
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}
allowedGroups := proxy.Auth.BearerAuth.DistributionGroups
allowedGroups := service.Auth.BearerAuth.DistributionGroups
if len(allowedGroups) == 0 {
return nil
}