mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-08 17:59:56 +00:00
Merge branch 'main' into feature/client-metrics
# Conflicts: # client/internal/debug/debug.go # client/internal/engine.go # client/server/debug.go
This commit is contained in:
499
client/internal/auth/auth.go
Normal file
499
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
mutex sync.RWMutex
|
||||
client *mgm.GrpcClient
|
||||
config *profilemanager.Config
|
||||
privateKey wgtypes.Key
|
||||
mgmURL *url.URL
|
||||
mgmTLSEnabled bool
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth instance that manages authentication flows
|
||||
// It establishes a connection to the management server that will be reused for all operations
|
||||
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||
// Validate WireGuard private key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine TLS setting based on URL scheme
|
||||
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
return &Auth{
|
||||
client: mgmClient,
|
||||
config: config,
|
||||
privateKey: myPrivateKey,
|
||||
mgmURL: mgmURL,
|
||||
mgmTLSEnabled: mgmTLSEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the management client connection
|
||||
func (a *Auth) Close() error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
if a.client == nil {
|
||||
return nil
|
||||
}
|
||||
return a.client.Close()
|
||||
}
|
||||
|
||||
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||
var supportsSSO bool
|
||||
|
||||
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
// Try PKCE flow first
|
||||
_, err := a.getPKCEFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if PKCE is not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// PKCE not supported, try Device flow
|
||||
_, err = a.getDeviceFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if Device flow is also not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// Neither PKCE nor Device flow is supported
|
||||
supportsSSO = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Device flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
}
|
||||
|
||||
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
})
|
||||
|
||||
return supportsSSO, err
|
||||
}
|
||||
|
||||
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||
// This avoids creating a new connection to the management server
|
||||
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||
var flow OAuthFlow
|
||||
var err error
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
if forceDeviceAuth {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
|
||||
// Try PKCE flow first
|
||||
flow, err = a.getPKCEFlow(client)
|
||||
if err != nil {
|
||||
// If PKCE not supported, try Device flow
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return flow, err
|
||||
}
|
||||
|
||||
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
needsLogin = false
|
||||
return err
|
||||
})
|
||||
|
||||
return needsLogin, err
|
||||
}
|
||||
|
||||
// Login attempts to log in or register the client with the management server
|
||||
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
|
||||
isAuthError = false
|
||||
return nil
|
||||
})
|
||||
|
||||
return err, isAuthError
|
||||
}
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
ClientCertPair: a.config.ClientCertKeyPair,
|
||||
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||
}
|
||||
|
||||
if err := validatePKCEConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
}
|
||||
|
||||
// Keep compatibility with older management versions
|
||||
if config.Scope == "" {
|
||||
config.Scope = "openid"
|
||||
}
|
||||
|
||||
if err := validateDeviceAuthConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
)
|
||||
}
|
||||
|
||||
// reconnect closes the current connection and creates a new one
|
||||
// It checks if the brokenClient is still the current client before reconnecting
|
||||
// to avoid multiple threads reconnecting unnecessarily
|
||||
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||
if a.client != brokenClient {
|
||||
log.Debugf("client already reconnected by another thread, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new connection FIRST, before closing the old one
|
||||
// This ensures a.client is never nil, preventing panics in other threads
|
||||
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||
// Keep the old client if reconnection fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Close old connection AFTER new one is successfully created
|
||||
oldClient := a.client
|
||||
a.client = mgmClient
|
||||
|
||||
if oldClient != nil {
|
||||
if err := oldClient.Close(); err != nil {
|
||||
log.Debugf("error closing old connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// These error codes indicate connection issues
|
||||
return s.Code() == codes.Unavailable ||
|
||||
s.Code() == codes.DeadlineExceeded ||
|
||||
s.Code() == codes.Canceled ||
|
||||
s.Code() == codes.Internal
|
||||
}
|
||||
|
||||
// withRetry wraps an operation with exponential backoff retry logic
|
||||
// It automatically reconnects on connection errors
|
||||
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||
backoffSettings := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.5,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 2 * time.Minute,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
backoffSettings.Reset()
|
||||
|
||||
return backoff.RetryNotify(
|
||||
func() error {
|
||||
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||
a.mutex.RLock()
|
||||
currentClient := a.client
|
||||
a.mutex.RUnlock()
|
||||
|
||||
if currentClient == nil {
|
||||
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||
}
|
||||
|
||||
// Execute operation with the captured client
|
||||
err := operation(currentClient)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||
if isConnectionError(err) {
|
||||
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||
|
||||
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||
return reconnectErr
|
||||
}
|
||||
// Return the original error to trigger retry with the new connection
|
||||
return err
|
||||
}
|
||||
|
||||
// For authentication errors, don't retry
|
||||
if isAuthenticationError(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
backoff.WithContext(backoffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||
func isAuthenticationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||
func isPermissionDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
return isAuthenticationError(err)
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
@@ -26,12 +25,56 @@ const (
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
providerConfig DeviceAuthProviderConfig
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the device authorization flow
|
||||
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||
d.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
pkceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
return pkceFlowInfo, nil
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
deviceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
return deviceFlowInfo, nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
@@ -35,17 +34,67 @@ const (
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validatePKCEConfig validates PKCE provider configuration
|
||||
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
providerConfig PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||
p.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -59,6 +59,7 @@ block.prof: Block profiling information.
|
||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||
allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
cpu.prof: CPU profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
|
||||
|
||||
@@ -231,6 +232,8 @@ type BundleGenerator struct {
|
||||
statusRecorder *peer.Status
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logPath string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
|
||||
anonymize bool
|
||||
@@ -251,6 +254,8 @@ type GeneratorDependencies struct {
|
||||
StatusRecorder *peer.Status
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogPath string
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
ClientMetrics MetricsExporter
|
||||
}
|
||||
|
||||
@@ -268,6 +273,8 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
statusRecorder: deps.StatusRecorder,
|
||||
syncResponse: deps.SyncResponse,
|
||||
logPath: deps.LogPath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
@@ -332,6 +339,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addCPUProfile(); err != nil {
|
||||
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
@@ -412,6 +423,10 @@ func (g *BundleGenerator) addStatus() error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
if g.refreshStatus != nil {
|
||||
g.refreshStatus()
|
||||
}
|
||||
|
||||
fullStatus := g.statusRecorder.GetFullStatus()
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
@@ -554,6 +569,19 @@ func (g *BundleGenerator) addProf() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCPUProfile() error {
|
||||
if len(g.cpuProfile) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(g.cpuProfile)
|
||||
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
|
||||
return fmt.Errorf("add CPU profile to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -81,7 +81,10 @@ func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
logger.Debug("received local resolver request with no question")
|
||||
|
||||
@@ -71,6 +71,11 @@ type upstreamResolverBase struct {
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
@@ -114,7 +119,10 @@ func (u *upstreamResolverBase) Stop() {
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
u.prepareRequest(r)
|
||||
|
||||
@@ -123,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if u.tryUpstreamServers(w, r, logger) {
|
||||
return
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
if !ok {
|
||||
u.writeErrorResponse(w, r, logger)
|
||||
}
|
||||
|
||||
u.writeErrorResponse(w, r, logger)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
@@ -136,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -149,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if u.queryUpstream(w, r, upstream, timeout, logger) {
|
||||
return true
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
}
|
||||
}
|
||||
return false
|
||||
return false, failures
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
@@ -171,31 +185,32 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
|
||||
return false
|
||||
return u.handleUpstreamError(err, upstream, startTime)
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||
return false
|
||||
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
|
||||
return
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
|
||||
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
|
||||
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
|
||||
timeoutMsg += " " + peerInfo
|
||||
reason += " " + peerInfo
|
||||
}
|
||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||
logger.Warn(timeoutMsg)
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
@@ -215,16 +230,34 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||
totalUpstreams := len(u.upstreamServers)
|
||||
failedCount := len(failures)
|
||||
failureSummary := formatFailures(failures)
|
||||
|
||||
if succeeded {
|
||||
logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||
} else {
|
||||
logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func formatFailures(failures []upstreamFailure) string {
|
||||
parts := make([]string, 0, len(failures))
|
||||
for _, f := range failures {
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason))
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
@@ -468,7 +501,6 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
|
||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||
func FormatPeerStatus(peerState *peer.State) string {
|
||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -140,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
|
||||
return c.r, c.rtt, c.err
|
||||
}
|
||||
|
||||
type mockUpstreamResponse struct {
|
||||
msg *dns.Msg
|
||||
err error
|
||||
}
|
||||
|
||||
type mockUpstreamResolverPerServer struct {
|
||||
responses map[string]mockUpstreamResponse
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
if r, ok := c.responses[upstream]; ok {
|
||||
return r.msg, c.rtt, r.err
|
||||
}
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolver{
|
||||
err: dns.ErrTime,
|
||||
@@ -191,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
t.Errorf("should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
successAnswer := "192.0.2.100"
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
upstream1 mockUpstreamResponse
|
||||
upstream2 mockUpstreamResponse
|
||||
expectedRcode int
|
||||
expectAnswer bool
|
||||
expectTrySecond bool
|
||||
}{
|
||||
{
|
||||
name: "success on first upstream",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: false,
|
||||
},
|
||||
{
|
||||
name: "SERVFAIL from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "REFUSED from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "NXDOMAIN from first should NOT try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeNameError,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: false,
|
||||
},
|
||||
{
|
||||
name: "timeout from first should try second",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "no response from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: nil},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "both upstreams return SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "both upstreams timeout",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first SERVFAIL then timeout",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first timeout then SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first REFUSED then SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var queriedUpstreams []string
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream1.String(): tc.upstream1,
|
||||
upstream2.String(): tc.upstream2,
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
trackingClient := &trackingMockClient{
|
||||
inner: mockClient,
|
||||
queriedUpstreams: &queriedUpstreams,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: trackingClient,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode")
|
||||
|
||||
if tc.expectAnswer {
|
||||
require.NotEmpty(t, responseMSG.Answer, "expected answer records")
|
||||
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||
}
|
||||
|
||||
if tc.expectTrySecond {
|
||||
assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams")
|
||||
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||
assert.Equal(t, upstream2.String(), queriedUpstreams[1])
|
||||
} else {
|
||||
assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream")
|
||||
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type trackingMockClient struct {
|
||||
inner *mockUpstreamResolverPerServer
|
||||
queriedUpstreams *[]string
|
||||
}
|
||||
|
||||
func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
*t.queriedUpstreams = append(*t.queriedUpstreams, upstream)
|
||||
return t.inner.exchange(ctx, upstream, r)
|
||||
}
|
||||
|
||||
func buildMockResponse(rcode int, answer string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.Response = true
|
||||
m.Rcode = rcode
|
||||
|
||||
if rcode == dns.RcodeSuccess && answer != "" {
|
||||
m.Answer = []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "example.com.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: net.ParseIP(answer),
|
||||
},
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
upstream := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamServers: []netip.AddrPort{upstream},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
failures []upstreamFailure
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty slice",
|
||||
failures: []upstreamFailure{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single failure",
|
||||
failures: []upstreamFailure{
|
||||
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||
},
|
||||
expected: "8.8.8.8:53=SERVFAIL",
|
||||
},
|
||||
{
|
||||
name: "multiple failures",
|
||||
failures: []upstreamFailure{
|
||||
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||
{upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"},
|
||||
},
|
||||
expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := formatFailures(tc.failures)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -516,6 +516,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
// Set up notrack rules immediately after proxy is listening to prevent
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -628,6 +632,23 @@ func (e *Engine) initFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
|
||||
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
|
||||
func (e *Engine) setupWGProxyNoTrack() {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
|
||||
proxyPort := e.wgInterface.GetProxyPort()
|
||||
if proxyPort == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
|
||||
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) blockLanAccess() {
|
||||
if e.config.BlockInbound {
|
||||
// no need to set up extra deny rules if inbound is already blocked in general
|
||||
@@ -1062,6 +1083,9 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
ClientMetrics: e.clientMetrics,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
},
|
||||
}
|
||||
|
||||
bundleJobParams := debug.BundleConfig{
|
||||
@@ -1654,6 +1678,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
@@ -1845,7 +1870,7 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1859,23 +1884,8 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
stuns := slices.Clone(e.STUNs)
|
||||
turns := slices.Clone(e.TURNs)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
stats, err := e.wgInterface.GetStats()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get wireguard stats: %v", err)
|
||||
e.syncMsgMux.Unlock()
|
||||
return false
|
||||
}
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
// wgStats could be zero value, in which case we just reset the stats
|
||||
wgStats, ok := stats[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
}
|
||||
if err := e.statusRecorder.RefreshWireGuardStats(); err != nil {
|
||||
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -107,6 +107,7 @@ type MockWGIface struct {
|
||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetProxyPortFunc func() uint16
|
||||
GetNetFunc func() *netstack.Net
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
@@ -203,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxyPort() uint16 {
|
||||
if m.GetProxyPortFunc != nil {
|
||||
return m.GetProxyPortFunc()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemoveEndpointAddress(key string) error
|
||||
RemovePeer(peerKey string) error
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
||||
mgmURL := config.ManagementURL
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
info.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1145,6 +1145,38 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
return d.wgIface.FullStats()
|
||||
}
|
||||
|
||||
// RefreshWireGuardStats fetches fresh WireGuard statistics from the interface
|
||||
// and updates the cached peer states. This ensures accurate handshake times and
|
||||
// transfer statistics in status reports without running full health probes.
|
||||
func (d *Status) RefreshWireGuardStats() error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if d.wgIface == nil {
|
||||
return nil // silently skip if interface not set
|
||||
}
|
||||
|
||||
stats, err := d.wgIface.FullStats()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get wireguard stats: %w", err)
|
||||
}
|
||||
|
||||
// Update each peer's WireGuard statistics
|
||||
for _, peerStats := range stats.Peers {
|
||||
peerState, ok := d.peers[peerStats.PublicKey]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
peerState.LastWireguardHandshake = peerStats.LastHandshake
|
||||
peerState.BytesRx = peerStats.RxBytes
|
||||
peerState.BytesTx = peerStats.TxBytes
|
||||
d.peers[peerStats.PublicKey] = peerState
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type EventQueue struct {
|
||||
maxSize int
|
||||
events []*proto.SystemEvent
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
|
||||
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
|
||||
Relayed: isRelayed(pair),
|
||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||
}
|
||||
@@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
// wait local endpoint configuration
|
||||
time.Sleep(time.Second)
|
||||
addrString := pair.Remote.Address()
|
||||
parsed, err := netip.ParseAddr(addrString)
|
||||
if (err == nil) && (parsed.Is6()) {
|
||||
addrString = fmt.Sprintf("[%s]", addrString)
|
||||
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
|
||||
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
|
||||
if err != nil {
|
||||
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
|
||||
return
|
||||
@@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||
sessionID := w.SessionID()
|
||||
stats := agent.GetCandidatePairsStats()
|
||||
localCandidates, _ := agent.GetLocalCandidates()
|
||||
remoteCandidates, _ := agent.GetRemoteCandidates()
|
||||
|
||||
localMap := make(map[string]ice.Candidate)
|
||||
for _, c := range localCandidates {
|
||||
localMap[c.ID()] = c
|
||||
}
|
||||
remoteMap := make(map[string]ice.Candidate)
|
||||
for _, c := range remoteCandidates {
|
||||
remoteMap[c.ID()] = c
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.State == ice.CandidatePairStateSucceeded {
|
||||
local, lok := localMap[stat.LocalCandidateID]
|
||||
remote, rok := remoteMap[stat.RemoteCandidateID]
|
||||
if !lok || !rok {
|
||||
continue
|
||||
}
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||
sessionID,
|
||||
local.NetworkType(), local.Type(), local.Address(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||
stat.CurrentRoundTripTime*1000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
|
||||
return func(state ice.ConnectionState) {
|
||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||
switch state {
|
||||
case ice.ConnectionStateConnected:
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
w.logSuccessfulPaths(agent)
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||
type PKCEAuthorizationFlow struct {
|
||||
ProviderConfig PKCEAuthProviderConfig
|
||||
}
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
authFlow := PKCEAuthorizationFlow{
|
||||
ProviderConfig: PKCEAuthProviderConfig{
|
||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
||||
},
|
||||
}
|
||||
|
||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return authFlow, nil
|
||||
}
|
||||
|
||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLog = slog.LevelInfo
|
||||
defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL"
|
||||
)
|
||||
|
||||
func hashRosenpassKey(key []byte) string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write(key)
|
||||
@@ -45,7 +50,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
||||
}
|
||||
|
||||
rpKeyHash := hashRosenpassKey(public)
|
||||
log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
||||
}
|
||||
|
||||
@@ -101,7 +106,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error {
|
||||
|
||||
func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
Level: getLogLevel(),
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
|
||||
cfg := rp.Config{Logger: logger}
|
||||
@@ -133,6 +138,26 @@ func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getLogLevel() slog.Level {
|
||||
level, ok := os.LookupEnv(defaultLogLevelVar)
|
||||
if !ok {
|
||||
return defaultLog
|
||||
}
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return slog.LevelDebug
|
||||
case "info":
|
||||
return slog.LevelInfo
|
||||
case "warn":
|
||||
return slog.LevelWarn
|
||||
case "error":
|
||||
return slog.LevelError
|
||||
default:
|
||||
log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String())
|
||||
return defaultLog
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) OnDisconnected(peerKey string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user