refactor: add ValidateSession gRPC and streamline test setup

- Add ValidateSession gRPC method for proxy-side user validation
- Move group access validation from REST callback to gRPC layer
- Capture user info in access logs via CapturedData mutable pointer
- Create validate_session_test.go for gRPC validation tests
- Simplify auth_callback_integration_test.go to create accounts
  programmatically instead of using SQL file
- SQL test data file now only used by validate_session_test.go
This commit is contained in:
mlsmaycon
2026-02-10 20:31:03 +01:00
parent 0cb02bd906
commit eea6120cd0
15 changed files with 955 additions and 238 deletions

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -795,3 +796,143 @@ func (s *ProxyServiceServer) getAccountProxyByDomain(ctx context.Context, accoun
return nil, fmt.Errorf("reverse proxy not found for domain %s in account %s", domain, accountID)
}
// ValidateSession validates a session token and checks if the user has access to the domain.
func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.ValidateSessionRequest) (*proto.ValidateSessionResponse, error) {
domain := req.GetDomain()
sessionToken := req.GetSessionToken()
if domain == "" || sessionToken == "" {
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "missing domain or session_token",
}, nil
}
proxy, err := s.getProxyByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Debug("ValidateSession: proxy not found")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "proxy_not_found",
}, nil
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(proxy.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Error("ValidateSession: decode public key")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "invalid_proxy_config",
}, nil
}
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Debug("ValidateSession: invalid session token")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "invalid_token",
}, nil
}
user, err := s.usersManager.GetUser(ctx, userID)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: user not found")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "user_not_found",
}, nil
}
if user.AccountID != proxy.AccountID {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"user_account": user.AccountID,
"proxy_account": proxy.AccountID,
}).Debug("ValidateSession: user account mismatch")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "account_mismatch",
}, nil
}
if err := s.checkGroupAccess(proxy, user); err != nil {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: access denied")
return &proto.ValidateSessionResponse{
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
}, nil
}
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"email": user.Email,
}).Debug("ValidateSession: access granted")
return &proto.ValidateSessionResponse{
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
}, nil
}
func (s *ProxyServiceServer) getProxyByDomain(ctx context.Context, domain string) (*reverseproxy.ReverseProxy, error) {
proxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
if err != nil {
return nil, fmt.Errorf("get reverse proxies: %w", err)
}
for _, proxy := range proxies {
if proxy.Domain == domain {
return proxy, nil
}
}
return nil, fmt.Errorf("reverse proxy not found for domain: %s", domain)
}
func (s *ProxyServiceServer) checkGroupAccess(proxy *reverseproxy.ReverseProxy, user *types.User) error {
if proxy.Auth.BearerAuth == nil || !proxy.Auth.BearerAuth.Enabled {
return nil
}
allowedGroups := proxy.Auth.BearerAuth.DistributionGroups
if len(allowedGroups) == 0 {
return nil
}
allowedSet := make(map[string]bool, len(allowedGroups))
for _, groupID := range allowedGroups {
allowedSet[groupID] = true
}
for _, groupID := range user.AutoGroups {
if allowedSet[groupID] {
return nil
}
}
return fmt.Errorf("user not in allowed groups")
}