mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Remove client secret from gRPC auth flow. The secret was originally included to support providers like Google Workspace that don't offer a proper PKCE flow, but this is no longer necessary with the embedded IdP. Deployments using such providers should migrate to the embedded IdP instead.
498 lines
15 KiB
Go
498 lines
15 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/google/uuid"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
"github.com/netbirdio/netbird/client/ssh"
|
|
"github.com/netbirdio/netbird/client/system"
|
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
)
|
|
|
|
// Auth manages authentication operations with the management server
|
|
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
|
type Auth struct {
|
|
mutex sync.RWMutex
|
|
client *mgm.GrpcClient
|
|
config *profilemanager.Config
|
|
privateKey wgtypes.Key
|
|
mgmURL *url.URL
|
|
mgmTLSEnabled bool
|
|
}
|
|
|
|
// NewAuth creates a new Auth instance that manages authentication flows
|
|
// It establishes a connection to the management server that will be reused for all operations
|
|
// The connection is automatically recreated with backoff if it becomes disconnected
|
|
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
|
// Validate WireGuard private key
|
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Determine TLS setting based on URL scheme
|
|
mgmTLSEnabled := mgmURL.Scheme == "https"
|
|
|
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
if err != nil {
|
|
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
|
return nil, err
|
|
}
|
|
|
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
|
|
return &Auth{
|
|
client: mgmClient,
|
|
config: config,
|
|
privateKey: myPrivateKey,
|
|
mgmURL: mgmURL,
|
|
mgmTLSEnabled: mgmTLSEnabled,
|
|
}, nil
|
|
}
|
|
|
|
// Close closes the management client connection
|
|
func (a *Auth) Close() error {
|
|
a.mutex.Lock()
|
|
defer a.mutex.Unlock()
|
|
|
|
if a.client == nil {
|
|
return nil
|
|
}
|
|
return a.client.Close()
|
|
}
|
|
|
|
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
|
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
|
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
|
// Automatically retries with backoff and reconnection on connection errors.
|
|
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
|
var supportsSSO bool
|
|
|
|
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
// Try PKCE flow first
|
|
_, err := a.getPKCEFlow(client)
|
|
if err == nil {
|
|
supportsSSO = true
|
|
return nil
|
|
}
|
|
|
|
// Check if PKCE is not supported
|
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
|
// PKCE not supported, try Device flow
|
|
_, err = a.getDeviceFlow(client)
|
|
if err == nil {
|
|
supportsSSO = true
|
|
return nil
|
|
}
|
|
|
|
// Check if Device flow is also not supported
|
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
|
// Neither PKCE nor Device flow is supported
|
|
supportsSSO = false
|
|
return nil
|
|
}
|
|
|
|
// Device flow check returned an error other than NotFound/Unimplemented
|
|
return err
|
|
}
|
|
|
|
// PKCE flow check returned an error other than NotFound/Unimplemented
|
|
return err
|
|
})
|
|
|
|
return supportsSSO, err
|
|
}
|
|
|
|
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
|
// This avoids creating a new connection to the management server
|
|
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
|
var flow OAuthFlow
|
|
var err error
|
|
|
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
if forceDeviceAuth {
|
|
flow, err = a.getDeviceFlow(client)
|
|
return err
|
|
}
|
|
|
|
// Try PKCE flow first
|
|
flow, err = a.getPKCEFlow(client)
|
|
if err != nil {
|
|
// If PKCE not supported, try Device flow
|
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
|
flow, err = a.getDeviceFlow(client)
|
|
return err
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
|
|
return flow, err
|
|
}
|
|
|
|
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
|
// Automatically retries with backoff and reconnection on connection errors.
|
|
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
var needsLogin bool
|
|
|
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
|
if isLoginNeeded(err) {
|
|
needsLogin = true
|
|
return nil
|
|
}
|
|
needsLogin = false
|
|
return err
|
|
})
|
|
|
|
return needsLogin, err
|
|
}
|
|
|
|
// Login attempts to log in or register the client with the management server
|
|
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
|
// Automatically retries with backoff and reconnection on connection errors.
|
|
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
|
if err != nil {
|
|
return err, false
|
|
}
|
|
|
|
var isAuthError bool
|
|
|
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
|
if serverKey != nil && isRegistrationNeeded(err) {
|
|
log.Debugf("peer registration required")
|
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
|
if err != nil {
|
|
isAuthError = isPermissionDenied(err)
|
|
return err
|
|
}
|
|
} else if err != nil {
|
|
isAuthError = isPermissionDenied(err)
|
|
return err
|
|
}
|
|
|
|
isAuthError = false
|
|
return nil
|
|
})
|
|
|
|
return err, isAuthError
|
|
}
|
|
|
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
|
serverKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
|
if err != nil {
|
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
|
return nil, err
|
|
}
|
|
log.Errorf("failed to retrieve pkce flow: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoConfig := protoFlow.GetProviderConfig()
|
|
config := &PKCEAuthProviderConfig{
|
|
Audience: protoConfig.GetAudience(),
|
|
ClientID: protoConfig.GetClientID(),
|
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
|
Scope: protoConfig.GetScope(),
|
|
RedirectURLs: protoConfig.GetRedirectURLs(),
|
|
UseIDToken: protoConfig.GetUseIDToken(),
|
|
ClientCertPair: a.config.ClientCertKeyPair,
|
|
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
|
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
|
}
|
|
|
|
if err := validatePKCEConfig(config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
flow, err := NewPKCEAuthorizationFlow(*config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return flow, nil
|
|
}
|
|
|
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
|
serverKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
|
if err != nil {
|
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
|
return nil, err
|
|
}
|
|
log.Errorf("failed to retrieve device flow: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoConfig := protoFlow.GetProviderConfig()
|
|
config := &DeviceAuthProviderConfig{
|
|
Audience: protoConfig.GetAudience(),
|
|
ClientID: protoConfig.GetClientID(),
|
|
Domain: protoConfig.Domain,
|
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
|
Scope: protoConfig.GetScope(),
|
|
UseIDToken: protoConfig.GetUseIDToken(),
|
|
}
|
|
|
|
// Keep compatibility with older management versions
|
|
if config.Scope == "" {
|
|
config.Scope = "openid"
|
|
}
|
|
|
|
if err := validateDeviceAuthConfig(config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
flow, err := NewDeviceAuthorizationFlow(*config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return flow, nil
|
|
}
|
|
|
|
// doMgmLogin performs the actual login operation with the management service
|
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
|
serverKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, nil, err
|
|
}
|
|
|
|
sysInfo := system.GetInfo(ctx)
|
|
a.setSystemInfoFlags(sysInfo)
|
|
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
|
return serverKey, loginResp, err
|
|
}
|
|
|
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
|
// Otherwise tries to register with the provided setupKey via command line.
|
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
|
serverPublicKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
validSetupKey, err := uuid.Parse(setupKey)
|
|
if err != nil && jwtToken == "" {
|
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
|
}
|
|
|
|
log.Debugf("sending peer registration request to Management Service")
|
|
info := system.GetInfo(ctx)
|
|
a.setSystemInfoFlags(info)
|
|
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
|
if err != nil {
|
|
log.Errorf("failed registering peer %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
log.Infof("peer has been successfully registered on Management Service")
|
|
|
|
return loginResp, nil
|
|
}
|
|
|
|
// setSystemInfoFlags sets all configuration flags on the provided system info
|
|
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
|
info.SetFlags(
|
|
a.config.RosenpassEnabled,
|
|
a.config.RosenpassPermissive,
|
|
a.config.ServerSSHAllowed,
|
|
a.config.DisableClientRoutes,
|
|
a.config.DisableServerRoutes,
|
|
a.config.DisableDNS,
|
|
a.config.DisableFirewall,
|
|
a.config.BlockLANAccess,
|
|
a.config.BlockInbound,
|
|
a.config.LazyConnectionEnabled,
|
|
a.config.EnableSSHRoot,
|
|
a.config.EnableSSHSFTP,
|
|
a.config.EnableSSHLocalPortForwarding,
|
|
a.config.EnableSSHRemotePortForwarding,
|
|
a.config.DisableSSHAuth,
|
|
)
|
|
}
|
|
|
|
// reconnect closes the current connection and creates a new one
|
|
// It checks if the brokenClient is still the current client before reconnecting
|
|
// to avoid multiple threads reconnecting unnecessarily
|
|
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
|
a.mutex.Lock()
|
|
defer a.mutex.Unlock()
|
|
|
|
// Double-check: if client has already been replaced by another thread, skip reconnection
|
|
if a.client != brokenClient {
|
|
log.Debugf("client already reconnected by another thread, skipping")
|
|
return nil
|
|
}
|
|
|
|
// Create new connection FIRST, before closing the old one
|
|
// This ensures a.client is never nil, preventing panics in other threads
|
|
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
|
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
|
if err != nil {
|
|
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
|
// Keep the old client if reconnection fails
|
|
return err
|
|
}
|
|
|
|
// Close old connection AFTER new one is successfully created
|
|
oldClient := a.client
|
|
a.client = mgmClient
|
|
|
|
if oldClient != nil {
|
|
if err := oldClient.Close(); err != nil {
|
|
log.Debugf("error closing old connection: %v", err)
|
|
}
|
|
}
|
|
|
|
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
|
return nil
|
|
}
|
|
|
|
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
|
func isConnectionError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
s, ok := status.FromError(err)
|
|
if !ok {
|
|
return false
|
|
}
|
|
// These error codes indicate connection issues
|
|
return s.Code() == codes.Unavailable ||
|
|
s.Code() == codes.DeadlineExceeded ||
|
|
s.Code() == codes.Canceled ||
|
|
s.Code() == codes.Internal
|
|
}
|
|
|
|
// withRetry wraps an operation with exponential backoff retry logic
|
|
// It automatically reconnects on connection errors
|
|
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
|
backoffSettings := &backoff.ExponentialBackOff{
|
|
InitialInterval: 500 * time.Millisecond,
|
|
RandomizationFactor: 0.5,
|
|
Multiplier: 1.5,
|
|
MaxInterval: 10 * time.Second,
|
|
MaxElapsedTime: 2 * time.Minute,
|
|
Stop: backoff.Stop,
|
|
Clock: backoff.SystemClock,
|
|
}
|
|
backoffSettings.Reset()
|
|
|
|
return backoff.RetryNotify(
|
|
func() error {
|
|
// Capture the client BEFORE the operation to ensure we track the correct client
|
|
a.mutex.RLock()
|
|
currentClient := a.client
|
|
a.mutex.RUnlock()
|
|
|
|
if currentClient == nil {
|
|
return status.Errorf(codes.Unavailable, "client is not initialized")
|
|
}
|
|
|
|
// Execute operation with the captured client
|
|
err := operation(currentClient)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// If it's a connection error, attempt reconnection using the client that was actually used
|
|
if isConnectionError(err) {
|
|
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
|
|
|
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
|
log.Errorf("reconnection failed: %v", reconnectErr)
|
|
return reconnectErr
|
|
}
|
|
// Return the original error to trigger retry with the new connection
|
|
return err
|
|
}
|
|
|
|
// For authentication errors, don't retry
|
|
if isAuthenticationError(err) {
|
|
return backoff.Permanent(err)
|
|
}
|
|
|
|
return err
|
|
},
|
|
backoff.WithContext(backoffSettings, ctx),
|
|
func(err error, duration time.Duration) {
|
|
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
|
},
|
|
)
|
|
}
|
|
|
|
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
|
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
|
func isAuthenticationError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
s, ok := status.FromError(err)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
|
}
|
|
|
|
// isPermissionDenied checks if the error is a PermissionDenied error.
|
|
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
|
func isPermissionDenied(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
s, ok := status.FromError(err)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return s.Code() == codes.PermissionDenied
|
|
}
|
|
|
|
func isLoginNeeded(err error) bool {
|
|
return isAuthenticationError(err)
|
|
}
|
|
|
|
func isRegistrationNeeded(err error) bool {
|
|
return isPermissionDenied(err)
|
|
}
|