mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
501 lines
16 KiB
Go
501 lines
16 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/google/uuid"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
"github.com/netbirdio/netbird/client/ssh"
|
|
"github.com/netbirdio/netbird/client/system"
|
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
)
|
|
|
|
// Auth manages authentication operations with the management server
|
|
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
|
type Auth struct {
|
|
mutex sync.RWMutex
|
|
client *mgm.GrpcClient
|
|
config *profilemanager.Config
|
|
privateKey wgtypes.Key
|
|
mgmURL *url.URL
|
|
mgmTLSEnabled bool
|
|
}
|
|
|
|
// NewAuth creates a new Auth instance that manages authentication flows
|
|
// It establishes a connection to the management server that will be reused for all operations
|
|
// The connection is automatically recreated with backoff if it becomes disconnected
|
|
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
|
// Validate WireGuard private key
|
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Determine TLS setting based on URL scheme
|
|
mgmTLSEnabled := mgmURL.Scheme == "https"
|
|
|
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
if err != nil {
|
|
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
|
return nil, err
|
|
}
|
|
|
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
|
|
return &Auth{
|
|
client: mgmClient,
|
|
config: config,
|
|
privateKey: myPrivateKey,
|
|
mgmURL: mgmURL,
|
|
mgmTLSEnabled: mgmTLSEnabled,
|
|
}, nil
|
|
}
|
|
|
|
// Close closes the management client connection
|
|
func (a *Auth) Close() error {
|
|
a.mutex.Lock()
|
|
defer a.mutex.Unlock()
|
|
|
|
if a.client == nil {
|
|
return nil
|
|
}
|
|
return a.client.Close()
|
|
}
|
|
|
|
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
|
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
|
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
|
// Automatically retries with backoff and reconnection on connection errors.
|
|
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
|
var supportsSSO bool
|
|
|
|
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
// Try PKCE flow first
|
|
_, err := a.getPKCEFlow(client)
|
|
if err == nil {
|
|
supportsSSO = true
|
|
return nil
|
|
}
|
|
|
|
// Check if PKCE is not supported
|
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
|
// PKCE not supported, try Device flow
|
|
_, err = a.getDeviceFlow(client)
|
|
if err == nil {
|
|
supportsSSO = true
|
|
return nil
|
|
}
|
|
|
|
// Check if Device flow is also not supported
|
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
|
// Neither PKCE nor Device flow is supported
|
|
supportsSSO = false
|
|
return nil
|
|
}
|
|
|
|
// Device flow check returned an error other than NotFound/Unimplemented
|
|
return err
|
|
}
|
|
|
|
// PKCE flow check returned an error other than NotFound/Unimplemented
|
|
return err
|
|
})
|
|
|
|
return supportsSSO, err
|
|
}
|
|
|
|
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
|
// This avoids creating a new connection to the management server
|
|
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
|
var flow OAuthFlow
|
|
var err error
|
|
|
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
if forceDeviceAuth {
|
|
flow, err = a.getDeviceFlow(client)
|
|
return err
|
|
}
|
|
|
|
// Try PKCE flow first
|
|
flow, err = a.getPKCEFlow(client)
|
|
if err != nil {
|
|
// If PKCE not supported, try Device flow
|
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
|
flow, err = a.getDeviceFlow(client)
|
|
return err
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
|
|
return flow, err
|
|
}
|
|
|
|
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
|
// Automatically retries with backoff and reconnection on connection errors.
|
|
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
var needsLogin bool
|
|
|
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
|
if isLoginNeeded(err) {
|
|
needsLogin = true
|
|
return nil
|
|
}
|
|
needsLogin = false
|
|
return err
|
|
})
|
|
|
|
return needsLogin, err
|
|
}
|
|
|
|
// Login attempts to log in or register the client with the management server
|
|
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
|
// Automatically retries with backoff and reconnection on connection errors.
|
|
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
|
if err != nil {
|
|
return err, false
|
|
}
|
|
|
|
var isAuthError bool
|
|
|
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
|
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
|
if serverKey != nil && isRegistrationNeeded(err) {
|
|
log.Debugf("peer registration required")
|
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
|
if err != nil {
|
|
isAuthError = isPermissionDenied(err)
|
|
return err
|
|
}
|
|
} else if err != nil {
|
|
isAuthError = isPermissionDenied(err)
|
|
return err
|
|
}
|
|
|
|
isAuthError = false
|
|
return nil
|
|
})
|
|
|
|
return err, isAuthError
|
|
}
|
|
|
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
|
serverKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
|
if err != nil {
|
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
|
return nil, err
|
|
}
|
|
log.Errorf("failed to retrieve pkce flow: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoConfig := protoFlow.GetProviderConfig()
|
|
config := &PKCEAuthProviderConfig{
|
|
Audience: protoConfig.GetAudience(),
|
|
ClientID: protoConfig.GetClientID(),
|
|
ClientSecret: protoConfig.GetClientSecret(),
|
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
|
Scope: protoConfig.GetScope(),
|
|
RedirectURLs: protoConfig.GetRedirectURLs(),
|
|
UseIDToken: protoConfig.GetUseIDToken(),
|
|
ClientCertPair: a.config.ClientCertKeyPair,
|
|
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
|
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
|
}
|
|
|
|
if err := validatePKCEConfig(config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
flow, err := NewPKCEAuthorizationFlow(*config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return flow, nil
|
|
}
|
|
|
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
|
serverKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
|
if err != nil {
|
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
|
return nil, err
|
|
}
|
|
log.Errorf("failed to retrieve device flow: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
protoConfig := protoFlow.GetProviderConfig()
|
|
config := &DeviceAuthProviderConfig{
|
|
Audience: protoConfig.GetAudience(),
|
|
ClientID: protoConfig.GetClientID(),
|
|
ClientSecret: protoConfig.GetClientSecret(),
|
|
Domain: protoConfig.Domain,
|
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
|
Scope: protoConfig.GetScope(),
|
|
UseIDToken: protoConfig.GetUseIDToken(),
|
|
}
|
|
|
|
// Keep compatibility with older management versions
|
|
if config.Scope == "" {
|
|
config.Scope = "openid"
|
|
}
|
|
|
|
if err := validateDeviceAuthConfig(config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
flow, err := NewDeviceAuthorizationFlow(*config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return flow, nil
|
|
}
|
|
|
|
// doMgmLogin performs the actual login operation with the management service
|
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
|
serverKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, nil, err
|
|
}
|
|
|
|
sysInfo := system.GetInfo(ctx)
|
|
a.setSystemInfoFlags(sysInfo)
|
|
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
|
return serverKey, loginResp, err
|
|
}
|
|
|
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
|
// Otherwise tries to register with the provided setupKey via command line.
|
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
|
serverPublicKey, err := client.GetServerPublicKey()
|
|
if err != nil {
|
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
validSetupKey, err := uuid.Parse(setupKey)
|
|
if err != nil && jwtToken == "" {
|
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
|
}
|
|
|
|
log.Debugf("sending peer registration request to Management Service")
|
|
info := system.GetInfo(ctx)
|
|
a.setSystemInfoFlags(info)
|
|
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
|
if err != nil {
|
|
log.Errorf("failed registering peer %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
log.Infof("peer has been successfully registered on Management Service")
|
|
|
|
return loginResp, nil
|
|
}
|
|
|
|
// setSystemInfoFlags sets all configuration flags on the provided system info
|
|
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
|
info.SetFlags(
|
|
a.config.RosenpassEnabled,
|
|
a.config.RosenpassPermissive,
|
|
a.config.ServerSSHAllowed,
|
|
a.config.DisableClientRoutes,
|
|
a.config.DisableServerRoutes,
|
|
a.config.DisableDNS,
|
|
a.config.DisableFirewall,
|
|
a.config.BlockLANAccess,
|
|
a.config.BlockInbound,
|
|
a.config.DisableIPv6,
|
|
a.config.LazyConnectionEnabled,
|
|
a.config.EnableSSHRoot,
|
|
a.config.EnableSSHSFTP,
|
|
a.config.EnableSSHLocalPortForwarding,
|
|
a.config.EnableSSHRemotePortForwarding,
|
|
a.config.DisableSSHAuth,
|
|
)
|
|
}
|
|
|
|
// reconnect closes the current connection and creates a new one
|
|
// It checks if the brokenClient is still the current client before reconnecting
|
|
// to avoid multiple threads reconnecting unnecessarily
|
|
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
|
a.mutex.Lock()
|
|
defer a.mutex.Unlock()
|
|
|
|
// Double-check: if client has already been replaced by another thread, skip reconnection
|
|
if a.client != brokenClient {
|
|
log.Debugf("client already reconnected by another thread, skipping")
|
|
return nil
|
|
}
|
|
|
|
// Create new connection FIRST, before closing the old one
|
|
// This ensures a.client is never nil, preventing panics in other threads
|
|
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
|
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
|
if err != nil {
|
|
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
|
// Keep the old client if reconnection fails
|
|
return err
|
|
}
|
|
|
|
// Close old connection AFTER new one is successfully created
|
|
oldClient := a.client
|
|
a.client = mgmClient
|
|
|
|
if oldClient != nil {
|
|
if err := oldClient.Close(); err != nil {
|
|
log.Debugf("error closing old connection: %v", err)
|
|
}
|
|
}
|
|
|
|
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
|
return nil
|
|
}
|
|
|
|
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
|
func isConnectionError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
s, ok := status.FromError(err)
|
|
if !ok {
|
|
return false
|
|
}
|
|
// These error codes indicate connection issues
|
|
return s.Code() == codes.Unavailable ||
|
|
s.Code() == codes.DeadlineExceeded ||
|
|
s.Code() == codes.Canceled ||
|
|
s.Code() == codes.Internal
|
|
}
|
|
|
|
// withRetry wraps an operation with exponential backoff retry logic
|
|
// It automatically reconnects on connection errors
|
|
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
|
backoffSettings := &backoff.ExponentialBackOff{
|
|
InitialInterval: 500 * time.Millisecond,
|
|
RandomizationFactor: 0.5,
|
|
Multiplier: 1.5,
|
|
MaxInterval: 10 * time.Second,
|
|
MaxElapsedTime: 2 * time.Minute,
|
|
Stop: backoff.Stop,
|
|
Clock: backoff.SystemClock,
|
|
}
|
|
backoffSettings.Reset()
|
|
|
|
return backoff.RetryNotify(
|
|
func() error {
|
|
// Capture the client BEFORE the operation to ensure we track the correct client
|
|
a.mutex.RLock()
|
|
currentClient := a.client
|
|
a.mutex.RUnlock()
|
|
|
|
if currentClient == nil {
|
|
return status.Errorf(codes.Unavailable, "client is not initialized")
|
|
}
|
|
|
|
// Execute operation with the captured client
|
|
err := operation(currentClient)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// If it's a connection error, attempt reconnection using the client that was actually used
|
|
if isConnectionError(err) {
|
|
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
|
|
|
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
|
log.Errorf("reconnection failed: %v", reconnectErr)
|
|
return reconnectErr
|
|
}
|
|
// Return the original error to trigger retry with the new connection
|
|
return err
|
|
}
|
|
|
|
// For authentication errors, don't retry
|
|
if isAuthenticationError(err) {
|
|
return backoff.Permanent(err)
|
|
}
|
|
|
|
return err
|
|
},
|
|
backoff.WithContext(backoffSettings, ctx),
|
|
func(err error, duration time.Duration) {
|
|
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
|
},
|
|
)
|
|
}
|
|
|
|
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
|
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
|
func isAuthenticationError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
s, ok := status.FromError(err)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
|
}
|
|
|
|
// isPermissionDenied checks if the error is a PermissionDenied error.
|
|
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
|
func isPermissionDenied(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
s, ok := status.FromError(err)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return s.Code() == codes.PermissionDenied
|
|
}
|
|
|
|
func isLoginNeeded(err error) bool {
|
|
return isAuthenticationError(err)
|
|
}
|
|
|
|
func isRegistrationNeeded(err error) bool {
|
|
return isPermissionDenied(err)
|
|
}
|